This is an automated email from the ASF dual-hosted git repository.

jiangxb1987 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new eb37aa5  Revert "[SPARK-30667][CORE] Add allGather method to BarrierTaskContext"
eb37aa5 is described below

commit eb37aa5595badd79becf4d3d332404cbcdb1b12d
Author: Xingbo Jiang <xingbo.ji...@databricks.com>
AuthorDate: Thu Feb 13 17:48:19 2020 -0800

    Revert "[SPARK-30667][CORE] Add allGather method to BarrierTaskContext"
    
    This reverts commit 6001866cea1216da421c5acd71d6fc74228222ac.
---
 .../org/apache/spark/BarrierCoordinator.scala      | 113 ++-------------
 .../org/apache/spark/BarrierTaskContext.scala      | 153 +++++++--------------
 .../org/apache/spark/api/python/PythonRunner.scala |  51 ++-----
 .../spark/scheduler/BarrierTaskContextSuite.scala  |  74 ----------
 python/pyspark/taskcontext.py                      |  49 +------
 python/pyspark/tests/test_taskcontext.py           |  20 ---
 6 files changed, 79 insertions(+), 381 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 042a266..4e41767 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,17 +17,12 @@
 
 package org.apache.spark
 
-import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.function.Consumer
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.json4s.JsonAST._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, 
SparkListenerStageCompleted}
@@ -104,15 +99,10 @@ private[spark] class BarrierCoordinator(
     // reset when a barrier() call fails due to timeout.
     private var barrierEpoch: Int = 0
 
-    // An Array of RPCCallContexts for barrier tasks that have made a blocking 
runBarrier() call
+    // An array of RPCCallContexts for barrier tasks that are waiting for 
reply of a barrier()
+    // call.
     private val requesters: ArrayBuffer[RpcCallContext] = new 
ArrayBuffer[RpcCallContext](numTasks)
 
-    // An Array of allGather messages for barrier tasks that have made a 
blocking runBarrier() call
-    private val allGatherMessages: ArrayBuffer[String] = new 
Array[String](numTasks).to[ArrayBuffer]
-
-    // The blocking requestMethod called by tasks to sync up for this stage 
attempt
-    private var requestMethodToSync: RequestMethod.Value = 
RequestMethod.BARRIER
-
     // A timer task that ensures we may timeout for a barrier() call.
     private var timerTask: TimerTask = null
 
@@ -140,32 +130,9 @@ private[spark] class BarrierCoordinator(
 
     // Process the global sync request. The barrier() call succeed if 
collected enough requests
     // within a configured time, otherwise fail all the pending requests.
-    def handleRequest(
-      requester: RpcCallContext,
-      request: RequestToSync
-    ): Unit = synchronized {
+    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit 
= synchronized {
       val taskId = request.taskAttemptId
       val epoch = request.barrierEpoch
-      val requestMethod = request.requestMethod
-      val partitionId = request.partitionId
-      val allGatherMessage = request match {
-        case ag: AllGatherRequestToSync => ag.allGatherMessage
-        case _ => ""
-      }
-
-      if (requesters.size == 0) {
-        requestMethodToSync = requestMethod
-      }
-
-      if (requestMethodToSync != requestMethod) {
-        requesters.foreach(
-          _.sendFailure(new SparkException(s"$barrierId tried to use 
requestMethod " +
-            s"`$requestMethod` during barrier epoch $barrierEpoch, which does 
not match " +
-            s"the current synchronized requestMethod `$requestMethodToSync`"
-          ))
-        )
-        cleanupBarrierStage(barrierId)
-      }
 
       // Require the number of tasks is correctly set from the 
BarrierTaskContext.
       require(request.numTasks == numTasks, s"Number of tasks of $barrierId is 
" +
@@ -186,7 +153,6 @@ private[spark] class BarrierCoordinator(
         }
         // Add the requester to array of RPCCallContexts pending for reply.
         requesters += requester
-        allGatherMessages(partitionId) = allGatherMessage
         logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received 
update from Task " +
           s"$taskId, current progress: ${requesters.size}/$numTasks.")
         if (maybeFinishAllRequesters(requesters, numTasks)) {
@@ -196,7 +162,6 @@ private[spark] class BarrierCoordinator(
             s"tasks, finished successfully.")
           barrierEpoch += 1
           requesters.clear()
-          allGatherMessages.clear()
           cancelTimerTask()
         }
       }
@@ -208,13 +173,7 @@ private[spark] class BarrierCoordinator(
         requesters: ArrayBuffer[RpcCallContext],
         numTasks: Int): Boolean = {
       if (requesters.size == numTasks) {
-        requestMethodToSync match {
-          case RequestMethod.BARRIER =>
-            requesters.foreach(_.reply(""))
-          case RequestMethod.ALL_GATHER =>
-            val json: String = compact(render(allGatherMessages))
-            requesters.foreach(_.reply(json))
-        }
+        requesters.foreach(_.reply(()))
         true
       } else {
         false
@@ -227,7 +186,6 @@ private[spark] class BarrierCoordinator(
       // messages come from current stage attempt shall fail.
       barrierEpoch = -1
       requesters.clear()
-      allGatherMessages.clear()
       cancelTimerTask()
     }
   }
@@ -241,11 +199,11 @@ private[spark] class BarrierCoordinator(
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
-    case request: RequestToSync =>
+    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
       // Get or init the ContextBarrierState correspond to the stage attempt.
-      val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
+      val barrierId = ContextBarrierId(stageId, stageAttemptId)
       states.computeIfAbsent(barrierId,
-        (key: ContextBarrierId) => new ContextBarrierState(key, 
request.numTasks))
+        (key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
       val barrierState = states.get(barrierId)
 
       barrierState.handleRequest(context, request)
@@ -258,16 +216,6 @@ private[spark] class BarrierCoordinator(
 
 private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
 
-private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
-  def numTasks: Int
-  def stageId: Int
-  def stageAttemptId: Int
-  def taskAttemptId: Long
-  def barrierEpoch: Int
-  def partitionId: Int
-  def requestMethod: RequestMethod.Value
-}
-
 /**
  * A global sync request message from BarrierTaskContext, by `barrier()` call. 
Each request is
  * identified by stageId + stageAttemptId + barrierEpoch.
@@ -276,44 +224,11 @@ private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
  * @param stageId ID of current stage
  * @param stageAttemptId ID of current stage attempt
  * @param taskAttemptId Unique ID of current task
- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple 
`barrier()` calls
- * @param partitionId ID of the current partition the task is assigned to
- * @param requestMethod The BarrierTaskContext method that was called to 
trigger BarrierCoordinator
+ * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple 
`barrier()` calls.
  */
-private[spark] case class BarrierRequestToSync(
-  numTasks: Int,
-  stageId: Int,
-  stageAttemptId: Int,
-  taskAttemptId: Long,
-  barrierEpoch: Int,
-  partitionId: Int,
-  requestMethod: RequestMethod.Value
-) extends RequestToSync
-
-/**
- * A global sync request message from BarrierTaskContext, by `allGather()` 
call. Each request is
- * identified by stageId + stageAttemptId + barrierEpoch.
- *
- * @param numTasks The number of global sync requests the BarrierCoordinator 
shall receive
- * @param stageId ID of current stage
- * @param stageAttemptId ID of current stage attempt
- * @param taskAttemptId Unique ID of current task
- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple 
`barrier()` calls
- * @param partitionId ID of the current partition the task is assigned to
- * @param requestMethod The BarrierTaskContext method that was called to 
trigger BarrierCoordinator
- * @param allGatherMessage Message sent from the BarrierTaskContext if 
requestMethod is ALL_GATHER
- */
-private[spark] case class AllGatherRequestToSync(
-  numTasks: Int,
-  stageId: Int,
-  stageAttemptId: Int,
-  taskAttemptId: Long,
-  barrierEpoch: Int,
-  partitionId: Int,
-  requestMethod: RequestMethod.Value,
-  allGatherMessage: String
-) extends RequestToSync
-
-private[spark] object RequestMethod extends Enumeration {
-  val BARRIER, ALL_GATHER = Value
-}
+private[spark] case class RequestToSync(
+    numTasks: Int,
+    stageId: Int,
+    stageAttemptId: Int,
+    taskAttemptId: Long,
+    barrierEpoch: Int) extends BarrierCoordinatorMessage
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 2263538..3d36980 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,19 +17,11 @@
 
 package org.apache.spark
 
-import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Properties, Timer, TimerTask}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.TimeoutException
 import scala.concurrent.duration._
-import scala.language.postfixOps
-
-import org.json4s.DefaultFormats
-import org.json4s.JsonAST._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.executor.TaskMetrics
@@ -67,31 +59,49 @@ class BarrierTaskContext private[spark] (
   // from different tasks within the same barrier stage attempt to succeed.
   private lazy val numTasks = getTaskInfos().size
 
-  private def getRequestToSync(
-    numTasks: Int,
-    stageId: Int,
-    stageAttemptNumber: Int,
-    taskAttemptId: Long,
-    barrierEpoch: Int,
-    partitionId: Int,
-    requestMethod: RequestMethod.Value,
-    allGatherMessage: String
-  ): RequestToSync = {
-    requestMethod match {
-      case RequestMethod.BARRIER =>
-        BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, 
taskAttemptId,
-          barrierEpoch, partitionId, requestMethod)
-      case RequestMethod.ALL_GATHER =>
-        AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, 
taskAttemptId,
-          barrierEpoch, partitionId, requestMethod, allGatherMessage)
-    }
-  }
-
-  private def runBarrier(
-    requestMethod: RequestMethod.Value,
-    allGatherMessage: String = ""
-  ): String = {
-
+  /**
+   * :: Experimental ::
+   * Sets a global barrier and waits until all tasks in this stage hit this 
barrier. Similar to
+   * MPI_Barrier function in MPI, the barrier() function call blocks until all 
tasks in the same
+   * stage have reached this routine.
+   *
+   * CAUTION! In a barrier stage, each task must have the same number of 
barrier() calls, in all
+   * possible code branches. Otherwise, you may get the job hanging or a 
SparkException after
+   * timeout. Some examples of '''misuses''' are listed below:
+   * 1. Only call barrier() function on a subset of all the tasks in the same 
barrier stage, it
+   * shall lead to timeout of the function call.
+   * {{{
+   *   rdd.barrier().mapPartitions { iter =>
+   *       val context = BarrierTaskContext.get()
+   *       if (context.partitionId() == 0) {
+   *           // Do nothing.
+   *       } else {
+   *           context.barrier()
+   *       }
+   *       iter
+   *   }
+   * }}}
+   *
+   * 2. Include barrier() function in a try-catch code block, this may lead to 
timeout of the
+   * second function call.
+   * {{{
+   *   rdd.barrier().mapPartitions { iter =>
+   *       val context = BarrierTaskContext.get()
+   *       try {
+   *           // Do something that might throw an Exception.
+   *           doSomething()
+   *           context.barrier()
+   *       } catch {
+   *           case e: Exception => logWarning("...", e)
+   *       }
+   *       context.barrier()
+   *       iter
+   *   }
+   * }}}
+   */
+  @Experimental
+  @Since("2.4.0")
+  def barrier(): Unit = {
     logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt 
$stageAttemptNumber) has entered " +
       s"the global sync, current barrier epoch is $barrierEpoch.")
     logTrace("Current callSite: " + Utils.getCallSite())
@@ -108,12 +118,10 @@ class BarrierTaskContext private[spark] (
     // Log the update of global sync every 60 seconds.
     timer.schedule(timerTask, 60000, 60000)
 
-    var json: String = ""
-
     try {
-      val abortableRpcFuture = barrierCoordinator.askAbortable[String](
-        message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
-          taskAttemptId, barrierEpoch, partitionId, requestMethod, 
allGatherMessage),
+      val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
+        message = RequestToSync(numTasks, stageId, stageAttemptNumber, 
taskAttemptId,
+          barrierEpoch),
         // Set a fixed timeout for RPC here, so users shall get a 
SparkException thrown by
         // BarrierCoordinator on timeout, instead of RPCTimeoutException from 
the RPC framework.
         timeout = new RpcTimeout(365.days, "barrierTimeout"))
@@ -125,7 +133,7 @@ class BarrierTaskContext private[spark] (
         while (!abortableRpcFuture.toFuture.isCompleted) {
           // wait RPC future for at most 1 second
           try {
-            json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 
1.second)
+            ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
           } catch {
             case _: TimeoutException | _: InterruptedException =>
               // If `TimeoutException` thrown, waiting RPC future reach 1 
second.
@@ -155,73 +163,6 @@ class BarrierTaskContext private[spark] (
       timerTask.cancel()
       timer.purge()
     }
-    json
-  }
-
-  /**
-   * :: Experimental ::
-   * Sets a global barrier and waits until all tasks in this stage hit this 
barrier. Similar to
-   * MPI_Barrier function in MPI, the barrier() function call blocks until all 
tasks in the same
-   * stage have reached this routine.
-   *
-   * CAUTION! In a barrier stage, each task must have the same number of 
barrier() calls, in all
-   * possible code branches. Otherwise, you may get the job hanging or a 
SparkException after
-   * timeout. Some examples of '''misuses''' are listed below:
-   * 1. Only call barrier() function on a subset of all the tasks in the same 
barrier stage, it
-   * shall lead to timeout of the function call.
-   * {{{
-   *   rdd.barrier().mapPartitions { iter =>
-   *       val context = BarrierTaskContext.get()
-   *       if (context.partitionId() == 0) {
-   *           // Do nothing.
-   *       } else {
-   *           context.barrier()
-   *       }
-   *       iter
-   *   }
-   * }}}
-   *
-   * 2. Include barrier() function in a try-catch code block, this may lead to 
timeout of the
-   * second function call.
-   * {{{
-   *   rdd.barrier().mapPartitions { iter =>
-   *       val context = BarrierTaskContext.get()
-   *       try {
-   *           // Do something that might throw an Exception.
-   *           doSomething()
-   *           context.barrier()
-   *       } catch {
-   *           case e: Exception => logWarning("...", e)
-   *       }
-   *       context.barrier()
-   *       iter
-   *   }
-   * }}}
-   */
-  @Experimental
-  @Since("2.4.0")
-  def barrier(): Unit = {
-    runBarrier(RequestMethod.BARRIER)
-    ()
-  }
-
-  /**
-   * :: Experimental ::
-   * Blocks until all tasks in the same stage have reached this routine. Each 
task passes in
-   * a message and returns with a list of all the messages passed in by each 
of those tasks.
-   *
-   * CAUTION! The allGather method requires the same precautions as the 
barrier method
-   *
-   * The message is type String rather than Array[Byte] because it is more 
convenient for
-   * the user at the cost of worse performance.
-   */
-  @Experimental
-  @Since("3.0.0")
-  def allGather(message: String): ArrayBuffer[String] = {
-    val json = runBarrier(RequestMethod.ALL_GATHER, message)
-    val jsonArray = parse(json)
-    implicit val formats = DefaultFormats
-    ArrayBuffer(jsonArray.extract[Array[String]]: _*)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index fa8bf0f..658e0d5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -24,13 +24,8 @@ import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
-import org.json4s.JsonAST._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@@ -243,18 +238,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
                   sock.setSoTimeout(10000)
                   authHelper.authClient(sock)
                   val input = new DataInputStream(sock.getInputStream())
-                  val requestMethod = input.readInt()
-                  // The BarrierTaskContext function may wait infinitely, 
socket shall not timeout
-                  // before the function finishes.
-                  sock.setSoTimeout(0)
-                  requestMethod match {
+                  input.readInt() match {
                     case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-                      barrierAndServe(requestMethod, sock)
-                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION 
=>
-                      val length = input.readInt()
-                      val message = new Array[Byte](length)
-                      input.readFully(message)
-                      barrierAndServe(requestMethod, sock, new String(message, 
UTF_8))
+                      // The barrier() function may wait infinitely, socket 
shall not timeout
+                      // before the function finishes.
+                      sock.setSoTimeout(0)
+                      barrierAndServe(sock)
+
                     case _ =>
                       val out = new DataOutputStream(new BufferedOutputStream(
                         sock.getOutputStream))
@@ -405,31 +395,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
 
     /**
-     * Gateway to call BarrierTaskContext methods.
+     * Gateway to call BarrierTaskContext.barrier().
      */
-    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = 
""): Unit = {
-      require(
-        serverSocket.isDefined,
-        "No available ServerSocket to redirect the BarrierTaskContext method 
call."
-      )
+    def barrierAndServe(sock: Socket): Unit = {
+      require(serverSocket.isDefined, "No available ServerSocket to redirect 
the barrier() call.")
+
       val out = new DataOutputStream(new 
BufferedOutputStream(sock.getOutputStream))
       try {
-        var result: String = ""
-        requestMethod match {
-          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-            context.asInstanceOf[BarrierTaskContext].barrier()
-            result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
-          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-            val messages: ArrayBuffer[String] = 
context.asInstanceOf[BarrierTaskContext].allGather(
-              message
-            )
-            result = compact(render(JArray(
-              messages.map(
-                (message) => JString(message)
-              ).toList
-            )))
-        }
-        writeUTF(result, out)
+        context.asInstanceOf[BarrierTaskContext].barrier()
+        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
       } catch {
         case e: SparkException =>
           writeUTF(e.getMessage, out)
@@ -664,7 +638,6 @@ private[spark] object SpecialLengths {
 
 private[spark] object BarrierTaskContextMessageProtocol {
   val BARRIER_FUNCTION = 1
-  val ALL_GATHER_FUNCTION = 2
   val BARRIER_RESULT_SUCCESS = "success"
   val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python 
side."
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index ed38b7f..fc8ac38 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
 
 import java.io.File
 
-import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark._
@@ -53,79 +52,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
     assert(times.max - times.min <= 1000)
   }
 
-  test("share messages with allGather() call") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
-    val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { it =>
-      val context = BarrierTaskContext.get()
-      // Sleep for a random time before global sync.
-      Thread.sleep(Random.nextInt(1000))
-      // Pass partitionId message in
-      val message = context.partitionId().toString
-      val messages = context.allGather(message)
-      messages.toList.iterator
-    }
-    // Take a sorted list of all the partitionId messages
-    val messages = rdd2.collect().head
-    // All the task partitionIds are shared
-    for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString)
-  }
-
-  test("throw exception if we attempt to synchronize with different blocking 
calls") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
-    val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { it =>
-      val context = BarrierTaskContext.get()
-      val partitionId = context.partitionId
-      if (partitionId == 0) {
-        context.barrier()
-      } else {
-        context.allGather(partitionId.toString)
-      }
-      Seq(null).iterator
-    }
-    val error = intercept[SparkException] {
-      rdd2.collect()
-    }.getMessage
-    assert(error.contains("does not match the current synchronized 
requestMethod"))
-  }
-
-  test("successively sync with allGather and barrier") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
-    val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { it =>
-      val context = BarrierTaskContext.get()
-      // Sleep for a random time before global sync.
-      Thread.sleep(Random.nextInt(1000))
-      context.barrier()
-      val time1 = System.currentTimeMillis()
-      // Sleep for a random time before global sync.
-      Thread.sleep(Random.nextInt(1000))
-      // Pass partitionId message in
-      val message = context.partitionId().toString
-      val messages = context.allGather(message)
-      val time2 = System.currentTimeMillis()
-      Seq((time1, time2)).iterator
-    }
-    val times = rdd2.collect()
-    // All the tasks shall finish the first round of global sync within a 
short time slot.
-    val times1 = times.map(_._1)
-    assert(times1.max - times1.min <= 1000)
-
-    // All the tasks shall finish the second round of global sync within a 
short time slot.
-    val times2 = times.map(_._2)
-    assert(times2.max - times2.min <= 1000)
-  }
-
   test("support multiple barrier() call within a single task") {
     initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 90bd234..d648f63 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,10 +16,9 @@
 #
 
 from __future__ import print_function
-import json
 
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
+from pyspark.serializers import write_int, UTF8Deserializer
 
 
 class TaskContext(object):
@@ -108,28 +107,18 @@ class TaskContext(object):
 
 
 BARRIER_FUNCTION = 1
-ALL_GATHER_FUNCTION = 2
 
 
-def _load_from_socket(port, auth_secret, function, all_gather_message=None):
+def _load_from_socket(port, auth_secret):
     """
     Load data from a given socket, this is a blocking method thus only return 
when the socket
     connection has been closed.
     """
     (sockfile, sock) = local_connect_and_auth(port, auth_secret)
-
-    # The call may block forever, so no timeout
+    # The barrier() call may block forever, so no timeout
     sock.settimeout(None)
-
-    if function == BARRIER_FUNCTION:
-        # Make a barrier() function call.
-        write_int(function, sockfile)
-    elif function == ALL_GATHER_FUNCTION:
-        # Make a all_gather() function call.
-        write_int(function, sockfile)
-        write_with_length(all_gather_message.encode("utf-8"), sockfile)
-    else:
-        raise ValueError("Unrecognized function type")
+    # Make a barrier() function call.
+    write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
 
     # Collect result.
@@ -210,33 +199,7 @@ class BarrierTaskContext(TaskContext):
             raise Exception("Not supported to call barrier() before initialize 
" +
                             "BarrierTaskContext.")
         else:
-            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
-
-    def allGather(self, message=""):
-        """
-        .. note:: Experimental
-
-        This function blocks until all tasks in the same stage have reached 
this routine.
-        Each task passes in a message and returns with a list of all the 
messages passed in
-        by each of those tasks.
-
-        .. warning:: In a barrier stage, each task much have the same number 
of `allGather()`
-            calls, in all possible code branches.
-            Otherwise, you may get the job hanging or a SparkException after 
timeout.
-        """
-        if not isinstance(message, str):
-            raise ValueError("Argument `message` must be of type `str`")
-        elif self._port is None or self._secret is None:
-            raise Exception("Not supported to call barrier() before initialize 
" +
-                            "BarrierTaskContext.")
-        else:
-            gathered_items = _load_from_socket(
-                self._port,
-                self._secret,
-                ALL_GATHER_FUNCTION,
-                message,
-            )
-            return [e for e in json.loads(gathered_items)]
+            _load_from_socket(self._port, self._secret)
 
     def getTaskInfos(self):
         """
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index f5dbd06..6095a38 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -134,26 +134,6 @@ class TaskContextTests(PySparkTestCase):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
-    def test_all_gather(self):
-        """
-        Verify that BarrierTaskContext.allGather() performs global sync among 
all barrier tasks
-        within a stage and passes messages properly.
-        """
-        rdd = self.sc.parallelize(range(10), 4)
-
-        def f(iterator):
-            yield sum(iterator)
-
-        def context_barrier(x):
-            tc = BarrierTaskContext.get()
-            time.sleep(random.randint(1, 10))
-            out = tc.allGather(str(context.partitionId()))
-            pids = [int(e) for e in out]
-            return [pids]
-
-        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
-        self.assertTrue(pids == [0, 1, 2, 3])
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all 
task infos in the


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to