This is an automated email from the ASF dual-hosted git repository. meng pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new f482187 [SPARK-30667][CORE] Add allGather method to BarrierTaskContext f482187 is described below commit f482187c127418d2ea538ac2551ae0fce1ddbc31 Author: sarthfrey-db <sarth.f...@databricks.com> AuthorDate: Thu Feb 13 16:15:00 2020 -0800 [SPARK-30667][CORE] Add allGather method to BarrierTaskContext ### What changes were proposed in this pull request? The `allGather` method is added to the `BarrierTaskContext`. This method contains the same functionality as the `BarrierTaskContext.barrier` method; it blocks the task until all tasks make the call, at which time they may continue execution. In addition, the `allGather` method takes an input message. Upon returning from the `allGather` the task receives a list of all the messages sent by all the tasks that made the `allGather` call. ### Why are the changes needed? There are many situations where having the tasks communicate in a synchronized way is useful. One simple example is if each task needs to start a server to serve requests from one another; first the tasks must find a free port (the result of which is undetermined beforehand) and then start making requests, but to do so they each must know the port chosen by the other tasks. An `allGather` method would allow them to inform each other of the port they will run on. ### Does this PR introduce any user-facing change? Yes, a `BarrierTaskContext.allGather` method will be available through the Scala, Java, and Python APIs. ### How was this patch tested? Most of the code path is already covered by tests for the `barrier` method, since this PR includes a refactor so that much code is shared by the `barrier` and `allGather` methods. However, a test is added to assert that an all gather on each task's partition ID will return a list of every partition ID. An example through the Python API: ```python >>> from pyspark import BarrierTaskContext >>> >>> def f(iterator): ... 
context = BarrierTaskContext.get() ... return [context.allGather('{}'.format(context.partitionId()))] ... >>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0] [u'3', u'1', u'0', u'2'] ``` Closes #27395 from sarthfrey/master. Lead-authored-by: sarthfrey-db <sarth.f...@databricks.com> Co-authored-by: sarthfrey <sarth.f...@gmail.com> Signed-off-by: Xiangrui Meng <m...@databricks.com> (cherry picked from commit 57254c9719f9af9ad985596ed7fbbaafa4052002) Signed-off-by: Xiangrui Meng <m...@databricks.com> --- .../org/apache/spark/BarrierCoordinator.scala | 113 +++++++++++++-- .../org/apache/spark/BarrierTaskContext.scala | 153 ++++++++++++++------- .../org/apache/spark/api/python/PythonRunner.scala | 51 +++++-- .../spark/scheduler/BarrierTaskContextSuite.scala | 74 ++++++++++ python/pyspark/taskcontext.py | 49 ++++++- python/pyspark/tests/test_taskcontext.py | 20 +++ 6 files changed, 381 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 4e41767..042a266 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,12 +17,17 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator( // reset when a barrier() call fails due to timeout. 
private var barrierEpoch: Int = 0 - // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() - // call. + // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call + private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] + + // The blocking requestMethod called by tasks to sync up for this stage attempt + private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER + // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + def handleRequest( + requester: RpcCallContext, + request: RequestToSync + ): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch + val requestMethod = request.requestMethod + val partitionId = request.partitionId + val allGatherMessage = request match { + case ag: AllGatherRequestToSync => ag.allGatherMessage + case _ => "" + } + + if (requesters.size == 0) { + requestMethodToSync = requestMethod + } + + if (requestMethodToSync != requestMethod) { + requesters.foreach( + _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + + s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + + s"the current synchronized requestMethod `$requestMethodToSync`" + )) + ) + cleanupBarrierStage(barrierId) + } // Require the number of tasks is correctly set from the BarrierTaskContext. 
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + @@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester + allGatherMessages(partitionId) = allGatherMessage logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (maybeFinishAllRequesters(requesters, numTasks)) { @@ -162,6 +196,7 @@ private[spark] class BarrierCoordinator( s"tasks, finished successfully.") barrierEpoch += 1 requesters.clear() + allGatherMessages.clear() cancelTimerTask() } } @@ -173,7 +208,13 @@ private[spark] class BarrierCoordinator( requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { if (requesters.size == numTasks) { - requesters.foreach(_.reply(())) + requestMethodToSync match { + case RequestMethod.BARRIER => + requesters.foreach(_.reply("")) + case RequestMethod.ALL_GATHER => + val json: String = compact(render(allGatherMessages)) + requesters.foreach(_.reply(json)) + } true } else { false @@ -186,6 +227,7 @@ private[spark] class BarrierCoordinator( // messages come from current stage attempt shall fail. barrierEpoch = -1 requesters.clear() + allGatherMessages.clear() cancelTimerTask() } } @@ -199,11 +241,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + case request: RequestToSync => // Get or init the ContextBarrierState correspond to the stage attempt. 
- val barrierId = ContextBarrierId(stageId, stageAttemptId) + val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -216,6 +258,16 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { + def numTasks: Int + def stageId: Int + def stageAttemptId: Int + def taskAttemptId: Long + def barrierEpoch: Int + def partitionId: Int + def requestMethod: RequestMethod.Value +} + /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is * identified by stageId + stageAttemptId + barrierEpoch. @@ -224,11 +276,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. 
+ * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator */ -private[spark] case class RequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int) extends BarrierCoordinatorMessage +private[spark] case class BarrierRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value +) extends RequestToSync + +/** + * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator + * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER + */ +private[spark] case class AllGatherRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String +) extends RequestToSync + +private[spark] object RequestMethod extends Enumeration { + val BARRIER, ALL_GATHER = Value +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 3d36980..2263538 100644 --- 
a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,11 +17,19 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.json4s.DefaultFormats +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - /** - * :: Experimental :: - * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to - * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same - * stage have reached this routine. - * - * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all - * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of '''misuses''' are listed below: - * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it - * shall lead to timeout of the function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * if (context.partitionId() == 0) { - * // Do nothing. - * } else { - * context.barrier() - * } - * iter - * } - * }}} - * - * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the - * second function call. 
- * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * try { - * // Do something that might throw an Exception. - * doSomething() - * context.barrier() - * } catch { - * case e: Exception => logWarning("...", e) - * } - * context.barrier() - * iter - * } - * }}} - */ - @Experimental - @Since("2.4.0") - def barrier(): Unit = { + private def getRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptNumber: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String + ): RequestToSync = { + requestMethod match { + case RequestMethod.BARRIER => + BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod) + case RequestMethod.ALL_GATHER => + AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod, allGatherMessage) + } + } + + private def runBarrier( + requestMethod: RequestMethod.Value, + allGatherMessage: String = "" + ): String = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. 
timer.schedule(timerTask, 60000, 60000) + var json: String = "" + try { - val abortableRpcFuture = barrierCoordinator.askAbortable[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch), + val abortableRpcFuture = barrierCoordinator.askAbortable[String]( + message = getRequestToSync(numTasks, stageId, stageAttemptNumber, + taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) @@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } + json + } + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. 
+ * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { + runBarrier(RequestMethod.BARRIER) + () + } + + /** + * :: Experimental :: + * Blocks until all tasks in the same stage have reached this routine. Each task passes in + * a message and returns with a list of all the messages passed in by each of those tasks. + * + * CAUTION! The allGather method requires the same precautions as the barrier method + * + * The message is type String rather than Array[Byte] because it is more convenient for + * the user at the cost of worse performance. 
+ */ + @Experimental + @Since("3.0.0") + def allGather(message: String): ArrayBuffer[String] = { + val json = runBarrier(RequestMethod.ALL_GATHER, message) + val jsonArray = parse(json) + implicit val formats = DefaultFormats + ArrayBuffer(jsonArray.extract[Array[String]]: _*) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 658e0d5..fa8bf0f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,8 +24,13 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -238,13 +243,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock.setSoTimeout(10000) authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) - input.readInt() match { + val requestMethod = input.readInt() + // The BarrierTaskContext function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + requestMethod match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - // The barrier() function may wait infinitely, socket shall not timeout - // before the function finishes. 
- sock.setSoTimeout(0) - barrierAndServe(sock) - + barrierAndServe(requestMethod, sock) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val length = input.readInt() + val message = new Array[Byte](length) + input.readFully(message) + barrierAndServe(requestMethod, sock, new String(message, UTF_8)) case _ => val out = new DataOutputStream(new BufferedOutputStream( sock.getOutputStream)) @@ -395,15 +405,31 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } /** - * Gateway to call BarrierTaskContext.barrier(). + * Gateway to call BarrierTaskContext methods. */ - def barrierAndServe(sock: Socket): Unit = { - require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") - + def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { + require( + serverSocket.isDefined, + "No available ServerSocket to redirect the BarrierTaskContext method call." + ) val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - context.asInstanceOf[BarrierTaskContext].barrier() - writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + var result: String = "" + requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].barrier() + result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( + message + ) + result = compact(render(JArray( + messages.map( + (message) => JString(message) + ).toList + ))) + } + writeUTF(result, out) } catch { case e: SparkException => writeUTF(e.getMessage, out) @@ -638,6 +664,7 @@ private[spark] object SpecialLengths { private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_FUNCTION = 1 + val ALL_GATHER_FUNCTION = 2 val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not 
recognized function call from python side." } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index fc8ac38..ed38b7f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.File +import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ @@ -52,6 +53,79 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { assert(times.max - times.min <= 1000) } + test("share messages with allGather() call") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. 
+ Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + messages.toList.iterator + } + // Take a sorted list of all the partitionId messages + val messages = rdd2.collect().head + // All the task partitionIds are shared + for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString) + } + + test("throw exception if we attempt to synchronize with different blocking calls") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + val partitionId = context.partitionId + if (partitionId == 0) { + context.barrier() + } else { + context.allGather(partitionId.toString) + } + Seq(null).iterator + } + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("does not match the current synchronized requestMethod")) + } + + test("successively sync with allGather and barrier") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. 
+ val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index d648f63..90bd234 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,9 +16,10 @@ # from __future__ import print_function +import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, UTF8Deserializer +from pyspark.serializers import write_int, write_with_length, UTF8Deserializer class TaskContext(object): @@ -107,18 +108,28 @@ class TaskContext(object): BARRIER_FUNCTION = 1 +ALL_GATHER_FUNCTION = 2 -def _load_from_socket(port, auth_secret): +def _load_from_socket(port, auth_secret, function, all_gather_message=None): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) - # The barrier() call may block forever, so no timeout + + # The call may block forever, so no timeout sock.settimeout(None) - # Make a barrier() function call. - write_int(BARRIER_FUNCTION, sockfile) + + if function == BARRIER_FUNCTION: + # Make a barrier() function call. + write_int(function, sockfile) + elif function == ALL_GATHER_FUNCTION: + # Make a all_gather() function call. + write_int(function, sockfile) + write_with_length(all_gather_message.encode("utf-8"), sockfile) + else: + raise ValueError("Unrecognized function type") sockfile.flush() # Collect result. 
@@ -199,7 +210,33 @@ class BarrierTaskContext(TaskContext): raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - _load_from_socket(self._port, self._secret) + _load_from_socket(self._port, self._secret, BARRIER_FUNCTION) + + def allGather(self, message=""): + """ + .. note:: Experimental + + This function blocks until all tasks in the same stage have reached this routine. + Each task passes in a message and returns with a list of all the messages passed in + by each of those tasks. + + .. warning:: In a barrier stage, each task much have the same number of `allGather()` + calls, in all possible code branches. + Otherwise, you may get the job hanging or a SparkException after timeout. + """ + if not isinstance(message, str): + raise ValueError("Argument `message` must be of type `str`") + elif self._port is None or self._secret is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + gathered_items = _load_from_socket( + self._port, + self._secret, + ALL_GATHER_FUNCTION, + message, + ) + return [e for e in json.loads(gathered_items)] def getTaskInfos(self): """ diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 6095a38..f5dbd06 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -134,6 +134,26 @@ class TaskContextTests(PySparkTestCase): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) + def test_all_gather(self): + """ + Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks + within a stage and passes messages properly. 
+ """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + out = tc.allGather(str(context.partitionId())) + pids = [int(e) for e in out] + return [pids] + + pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] + self.assertTrue(pids == [0, 1, 2, 3]) + def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org