Repository: spark
Updated Branches:
  refs/heads/master 7bca62f79 -> f15806a8f
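This commit replaces the Akka-actor-based block manager control plane (BlockManagerMasterActor, BlockManagerSlaveActor) with endpoints built on Spark's own RPC abstraction (BlockManagerMasterEndpoint, BlockManagerSlaveEndpoint). The mechanical change that recurs through every diff below: a `receiveWithLogging` block that replies via `sender ! result` becomes a `receiveAndReply(context)` block that replies via `context.reply(result)`, and callers hold an RpcEndpointRef instead of an ActorRef. A minimal sketch of the new endpoint shape, for orientation — EchoEndpoint, Ping, and Pong are illustrative names, not part of this commit; the API calls are the ones the diffs below exercise:

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

case class Ping(id: Int)
case class Pong(id: Int)

// ThreadSafeRpcEndpoint keeps the one-message-at-a-time guarantee the old
// actor provided, so an endpoint's mutable state needs no extra locking.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping(id) => context.reply(Pong(id))  // Akka style was: sender ! Pong(id)
  }
}

// Registration and a blocking ask, mirroring the updated tests below:
//   val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
//   val pong = ref.askWithReply[Pong](Ping(1))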
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala deleted file mode 100644 index 5b53280..0000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ /dev/null @@ -1,512 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable -import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ - -import akka.actor.{Actor, ActorRef} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} - -/** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. - */ -private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { - - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] - - // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] - - // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - - private val akkaTimeout = AkkaUtils.askTimeout(conf) - - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true - - case UpdateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) - - case GetLocations(blockId) => - sender ! getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) - - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) - - case GetMemoryStatus => - sender ! 
memoryStatus - - case GetStorageStatus => - sender ! storageStatus - - case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) - - case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) - - case RemoveRdd(rddId) => - sender ! removeRdd(rddId) - - case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) - - case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) - - case RemoveBlock(blockId) => - removeBlockFromWorkers(blockId) - sender ! true - - case RemoveExecutor(execId) => - removeExecutor(execId) - sender ! true - - case StopBlockManagerMaster => - sender ! true - context.stop(self) - - case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) - - case other => - logWarning("Got unknown message: " + other) - } - - private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks - // from the slaves. - - // Find all blocks for the given RDD, remove the block from both blockLocations and - // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocks.foreach { blockId => - val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) - bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) - blockLocations.remove(blockId) - } - - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher - val removeMsg = RemoveRdd(rddId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq - ) - } - - private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher - val removeMsg = RemoveShuffle(shuffleId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] - }.toSeq - ) - } - - /** - * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified - * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed - * from the executors, but not from the driver. - */ - private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher - val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) - val requiredBlockManagers = blockManagerInfo.values.filter { info => - removeFromDriver || !info.blockManagerId.isDriver - } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq - ) - } - - private def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - - // Remove the block manager from blockManagerIdByExecutor. - blockManagerIdByExecutor -= blockManagerId.executorId - - // Remove it from blockManagerInfo and remove all the blocks. 
- blockManagerInfo.remove(blockManagerId) - val iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockLocations.get(blockId) - locations -= blockManagerId - if (locations.size == 0) { - blockLocations.remove(blockId) - } - } - listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) - logInfo(s"Removing block manager $blockManagerId") - } - - private def removeExecutor(execId: String) { - logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") - blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - } - - /** - * Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. - */ - private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.isDriver && !isLocal - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - true - } - } - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: BlockId) { - val locations = blockLocations.get(blockId) - if (locations != null) { - locations.foreach { blockManagerId: BlockManagerId => - val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. - // Doesn't actually wait for a confirmation and the message might get lost. - // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) - } - } - } - } - - // Return a map from the block manager id to max memory and remaining memory. - private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - } - - private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) - }.toArray - } - - /** - * Return the block's status for all block managers, if any. NOTE: This is a - * potentially expensive operation and should only be used for testing. - * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block - * managers. - */ - private def blockStatus( - blockId: BlockId, - askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher - val getBlockStatus = GetBlockStatus(blockId) - /* - * Rather than blocking on the block status query, master actor should simply return - * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. - */ - blockManagerInfo.values.map { info => - val blockStatusFuture = - if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] - } else { - Future { info.getStatus(blockId) } - } - (info.blockManagerId, blockStatusFuture) - }.toMap - } - - /** - * Return the ids of blocks present in all the block managers that match the given filter. - * NOTE: This is a potentially expensive operation and should only be used for testing. 
- * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block - * managers. - */ - private def getMatchingBlockIds( - filter: BlockId => Boolean, - askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher - val getMatchingBlockIds = GetMatchingBlockIds(filter) - Future.sequence( - blockManagerInfo.values.map { info => - val future = - if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] - } else { - Future { info.blocks.keys.filter(filter).toSeq } - } - future - } - ).map(_.flatten.toSeq) - } - - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val time = System.currentTimeMillis() - if (!blockManagerInfo.contains(id)) { - blockManagerIdByExecutor.get(id.executorId) match { - case Some(oldId) => - // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " - + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) - case None => - } - logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) - - blockManagerIdByExecutor(id.executorId) = id - - blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) - } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) - } - - private def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): Boolean = { - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.isDriver && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - return true - } else { - return false - } - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - return true - } - - blockManagerInfo(blockManagerId).updateBlockInfo( - blockId, storageLevel, memSize, diskSize, tachyonSize) - - var locations: mutable.HashSet[BlockManagerId] = null - if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId) - } else { - locations = new mutable.HashSet[BlockManagerId] - blockLocations.put(blockId, locations) - } - - if (storageLevel.isValid) { - locations.add(blockManagerId) - } else { - locations.remove(blockManagerId) - } - - // Remove the block from master tracking if it has been removed on all slaves. - if (locations.size == 0) { - blockLocations.remove(blockId) - } - true - } - - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty - } - - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - blockIds.map(blockId => getLocations(blockId)) - } - - /** Get the list of the peers of the given block manager */ - private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - val blockManagerIds = blockManagerInfo.keySet - if (blockManagerIds.contains(blockManagerId)) { - blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq - } else { - Seq.empty - } - } - - /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * BlockManagerSlaveActor. 
- */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - for ( - blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port - ) yield { - (host, port) - } - } -} - -@DeveloperApi -case class BlockStatus( - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long) { - def isCached: Boolean = memSize + diskSize + tachyonSize > 0 -} - -@DeveloperApi -object BlockStatus { - def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) -} - -private[spark] class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) - extends Logging { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[BlockId, BlockStatus] - - def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo( - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long) { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize - - if (originalLevel.useMemory) { - _remainingMem += originalMemSize - } - } - - if (storageLevel.isValid) { - /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * The memSize here indicates the data size in or dropped from memory, - * tachyonSize here indicates the data size in or dropped from Tachyon, - * and the diskSize here indicates the data size in or dropped to disk. - * They can be both larger than 0, when a block is dropped from memory to disk. - * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ - if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) - _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) - } - if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize)) - logInfo("Added %s on tachyon on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. 
- val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) - } - if (blockStatus.storageLevel.useOffHeap) { - logInfo("Removed %s on %s on tachyon (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize))) - } - } - } - - def removeBlock(blockId: BlockId) { - if (_blocks.containsKey(blockId)) { - _remainingMem += _blocks.get(blockId).memSize - _blocks.remove(blockId) - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[BlockId, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala new file mode 100644 index 0000000..28c73a7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable +import scala.collection.JavaConversions._ +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.Utils + +/** + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. + */ +private[spark] +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { + + // Mapping from block manager id to the block manager's information. 
+ private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] + + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] + + // Mapping from block id to the set of block managers that have the block. + private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + + private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) + + case UpdateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) + + case GetLocations(blockId) => + context.reply(getLocations(blockId)) + + case GetLocationsMultipleBlockIds(blockIds) => + context.reply(getLocationsMultipleBlockIds(blockIds)) + + case GetPeers(blockManagerId) => + context.reply(getPeers(blockManagerId)) + + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) + + case GetMemoryStatus => + context.reply(memoryStatus) + + case GetStorageStatus => + context.reply(storageStatus) + + case GetBlockStatus(blockId, askSlaves) => + context.reply(blockStatus(blockId, askSlaves)) + + case GetMatchingBlockIds(filter, askSlaves) => + context.reply(getMatchingBlockIds(filter, askSlaves)) + + case RemoveRdd(rddId) => + context.reply(removeRdd(rddId)) + + case RemoveShuffle(shuffleId) => + context.reply(removeShuffle(shuffleId)) + + case RemoveBroadcast(broadcastId, removeFromDriver) => + context.reply(removeBroadcast(broadcastId, removeFromDriver)) + + case RemoveBlock(blockId) => + removeBlockFromWorkers(blockId) + context.reply(true) + + case RemoveExecutor(execId) => + removeExecutor(execId) + context.reply(true) + + case StopBlockManagerMaster => + context.reply(true) + stop() + + case BlockManagerHeartbeat(blockManagerId) => + context.reply(heartbeatReceived(blockManagerId)) + + } + + private def removeRdd(rddId: Int): Future[Seq[Int]] = { + // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // from the slaves. + + // Find all blocks for the given RDD, remove the block from both blockLocations and + // the blockManagerInfo that is tracking the blocks. + val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocks.foreach { blockId => + val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) + bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) + blockLocations.remove(blockId) + } + + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. 
+ val removeMsg = RemoveRdd(rddId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveEndpoint.sendWithReply[Int](removeMsg) + }.toSeq + ) + } + + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + // Nothing to do in the BlockManagerMasterEndpoint data structures + val removeMsg = RemoveShuffle(shuffleId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) + }.toSeq + ) + } + + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { + val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) + val requiredBlockManagers = blockManagerInfo.values.filter { info => + removeFromDriver || !info.blockManagerId.isDriver + } + Future.sequence( + requiredBlockManagers.map { bm => + bm.slaveEndpoint.sendWithReply[Int](removeMsg) + }.toSeq + ) + } + + private def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId + + // Remove it from blockManagerInfo and remove all the blocks. + blockManagerInfo.remove(blockManagerId) + val iterator = info.blocks.keySet.iterator + while (iterator.hasNext) { + val blockId = iterator.next + val locations = blockLocations.get(blockId) + locations -= blockManagerId + if (locations.size == 0) { + blockLocations.remove(blockId) + } + } + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) + logInfo(s"Removing block manager $blockManagerId") + } + + private def removeExecutor(execId: String) { + logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) + } + + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { + if (!blockManagerInfo.contains(blockManagerId)) { + blockManagerId.isDriver && !isLocal + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + true + } + } + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlockFromWorkers(blockId: BlockId) { + val locations = blockLocations.get(blockId) + if (locations != null) { + locations.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) + } + } + } + } + + // Return a map from the block manager id to max memory and remaining memory.
+ private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { + blockManagerInfo.map { case(blockManagerId, info) => + (blockManagerId, (info.maxMem, info.remainingMem)) + }.toMap + } + + private def storageStatus: Array[StorageStatus] = { + blockManagerInfo.map { case (blockManagerId, info) => + new StorageStatus(blockManagerId, info.maxMem, info.blocks) + }.toArray + } + + /** + * Return the block's status for all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master endpoint should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master endpoint's response to a previous message. + */ + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) + }.toMap + } + + /** + * Return the ids of blocks present in all the block managers that match the given filter. + * NOTE: This is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. 
+ */ + private def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Future[Seq[BlockId]] = { + val getMatchingBlockIds = GetMatchingBlockIds(filter) + Future.sequence( + blockManagerInfo.values.map { info => + val future = + if (askSlaves) { + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) + } else { + Future { info.blocks.keys.filter(filter).toSeq } + } + future + } + ).map(_.flatten.toSeq) + } + + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { + val time = System.currentTimeMillis() + if (!blockManagerInfo.contains(id)) { + blockManagerIdByExecutor.get(id.executorId) match { + case Some(oldId) => + // A block manager of the same executor already exists, so remove it (assumed dead) + logError("Got two different block manager registrations on same executor - " + + s" will replace old one $oldId with new one $id") + removeExecutor(id.executorId) + case None => + } + logInfo("Registering block manager %s with %s RAM, %s".format( + id.hostPort, Utils.bytesToString(maxMemSize), id)) + + blockManagerIdByExecutor(id.executorId) = id + + blockManagerInfo(id) = new BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + } + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + } + + private def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long): Boolean = { + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.isDriver && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + return true + } else { + return false + } + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + return true + } + + blockManagerInfo(blockManagerId).updateBlockInfo( + blockId, storageLevel, memSize, diskSize, tachyonSize) + + var locations: mutable.HashSet[BlockManagerId] = null + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId) + } else { + locations = new mutable.HashSet[BlockManagerId] + blockLocations.put(blockId, locations) + } + + if (storageLevel.isValid) { + locations.add(blockManagerId) + } else { + locations.remove(blockManagerId) + } + + // Remove the block from master tracking if it has been removed on all slaves. + if (locations.size == 0) { + blockLocations.remove(blockId) + } + true + } + + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty + } + + private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + blockIds.map(blockId => getLocations(blockId)) + } + + /** Get the list of the peers of the given block manager */ + private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + val blockManagerIds = blockManagerInfo.keySet + if (blockManagerIds.contains(blockManagerId)) { + blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq + } else { + Seq.empty + } + } + + /** + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. 
+ */ + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId) + ) yield { + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + } + } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } +} + +@DeveloperApi +case class BlockStatus( + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long) { + def isCached: Boolean = memSize + diskSize + tachyonSize > 0 +} + +@DeveloperApi +object BlockStatus { + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) +} + +private[spark] class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveEndpoint: RpcEndpointRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[BlockId, BlockStatus] + + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo( + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long) { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize + + if (originalLevel.useMemory) { + _remainingMem += originalMemSize + } + } + + if (storageLevel.isValid) { + /* isValid means it is either stored in-memory, on-disk or on-Tachyon. + * The memSize here indicates the data size in or dropped from memory, + * tachyonSize here indicates the data size in or dropped from Tachyon, + * and the diskSize here indicates the data size in or dropped to disk. + * They can be both larger than 0, when a block is dropped from memory to disk. + * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ + if (storageLevel.useMemory) { + _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) + _remainingMem -= memSize + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), + Utils.bytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + } + if (storageLevel.useOffHeap) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize)) + logInfo("Added %s on tachyon on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. 
+ val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), + Utils.bytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + } + if (blockStatus.storageLevel.useOffHeap) { + logInfo("Removed %s on %s on tachyon (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize))) + } + } + } + + def removeBlock(blockId: BlockId) { + if (_blocks.containsKey(blockId)) { + _remainingMem += _blocks.get(blockId).memSize + _blocks.remove(blockId) + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[BlockId, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 4824745..f89d8d7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala deleted file mode 100644 index 52fb896..0000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} - -import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive - -/** - * An actor to take commands from the master to execute options. For example, - * this is used to remove blocks from the slave's BlockManager. - */ -private[storage] -class BlockManagerSlaveActor( - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { - - import context.dispatcher - - // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { - blockManager.removeBlock(blockId) - true - } - - case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { - blockManager.removeRdd(rddId) - } - - case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { - if (mapOutputTracker != null) { - mapOutputTracker.unregisterShuffle(shuffleId) - } - SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) - } - - case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { - blockManager.removeBroadcast(broadcastId, tellMaster = true) - } - - case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) - - case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) - } - - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { - val future = Future { - logDebug(actionMessage) - body - } - future.onSuccess { case response => - logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) - } - future.onFailure { case t: Throwable => - logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala new file mode 100644 index 0000000..8980fa8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.storage.BlockManagerMessages._ + +/** + * An RpcEndpoint to take commands from the master to execute operations. For example, + * this is used to remove blocks from the slave's BlockManager. + */ +private[storage] +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + extends RpcEndpoint with Logging { + + private val asyncThreadPool = + Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) + + // Operations that involve removing blocks may be slow and should be done asynchronously + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RemoveBlock(blockId) => + doAsync[Boolean]("removing block " + blockId, context) { + blockManager.removeBlock(blockId) + true + } + + case RemoveRdd(rddId) => + doAsync[Int]("removing RDD " + rddId, context) { + blockManager.removeRdd(rddId) + } + + case RemoveShuffle(shuffleId) => + doAsync[Boolean]("removing shuffle " + shuffleId, context) { + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) + } + + case RemoveBroadcast(broadcastId, _) => + doAsync[Int]("removing broadcast " + broadcastId, context) { + blockManager.removeBroadcast(broadcastId, tellMaster = true) + } + + case GetBlockStatus(blockId, _) => + context.reply(blockManager.getStatus(blockId)) + + case GetMatchingBlockIds(filter, _) => + context.reply(blockManager.getMatchingBlockIds(filter)) + } + + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { + val future = Future { + logDebug(actionMessage) + body + } + future.onSuccess { case response => + logDebug("Done " + actionMessage + ", response is " + response) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) + } + future.onFailure { case t: Throwable => + logError("Error in " + actionMessage, t) + context.sendFailure(t) + } + } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/util/Utils.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7c85e28..0fdfaf3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1214,6 +1214,16 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block.
Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + /** * Execute a block of code, then a finally block, but if exceptions happen in * the finally block, do not suppress the original exception. http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 0000000..0fd570e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = 
new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e07bdb9..4f19c4f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -311,7 +311,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("self: call in onStop") { - @volatile var e: Throwable = null + @volatile var selfOption: Option[RpcEndpointRef] = null val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env @@ -321,20 +321,18 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } override def onStop(): Unit = { - self + selfOption = Option(self) } override def onError(cause: Throwable): Unit = { - e = cause } }) env.stop(endpointRef) eventually(timeout(5 seconds), interval(10 millis)) { - // Calling `self` in `onStop` is invalid - assert(e != null) - assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) } } @@ -342,7 +340,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it for(i <- 0 until 100) { @volatile var result = 0 - val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive = { @@ -475,7 +473,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("network events") { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive = { http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c8..b4de90b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import 
scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new 
BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba..283090e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -72,28 +66,25 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate 
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 18a477f..ef4873d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import akka.actor.{ActorSystem, Props}
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
 import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.storage._
 import org.apache.spark.streaming.receiver._
 import org.apache.spark.streaming.util._
-import org.apache.spark.util.{AkkaUtils, ManualClock, Utils}
+import org.apache.spark.util.{ManualClock, Utils}
 import WriteAheadLogBasedBlockHandler._
 import WriteAheadLogSuite._
 
@@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
   val manualClock = new ManualClock
   val blockManagerSize = 10000000
 
-  var actorSystem: ActorSystem = null
+  var rpcEnv: RpcEnv = null
   var blockManagerMaster: BlockManagerMaster = null
   var blockManager: BlockManager = null
   var tempDirectory: File = null
 
   before {
-    val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
-      "test", "localhost", 0, conf = conf, securityManager = securityMgr)
-    this.actorSystem = actorSystem
-    conf.set("spark.driver.port", boundPort.toString)
+    rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
+    conf.set("spark.driver.port", rpcEnv.address.port.toString)
 
-    blockManagerMaster = new BlockManagerMaster(
-      actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
-      conf, true)
+    blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
+      new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
 
-    blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer,
+    blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer,
       blockManagerSize, conf, mapOutputTracker, shuffleManager,
       new NioBlockTransferService(conf, securityMgr), securityMgr, 0)
     blockManager.initialize("app-id")
@@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
       blockManagerMaster.stop()
      blockManagerMaster = null
     }
-    actorSystem.shutdown()
-    actorSystem.awaitTermination()
-    actorSystem = null
+    rpcEnv.shutdown()
+    rpcEnv.awaitTermination()
+    rpcEnv = null
 
     Utils.deleteRecursively(tempDirectory)
   }
conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null Utils.deleteRecursively(tempDirectory) } http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala ---------------------------------------------------------------------- diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 455554e..24a1e02 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, - SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -72,8 +70,8 @@ private[spark] class ApplicationMaster( @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -240,22 +238,21 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. 
@@ -272,8 +269,8 @@ private[spark] class ApplicationMaster(
           ApplicationMaster.EXIT_SC_NOT_INITED,
           "Timed out waiting for SparkContext.")
     } else {
-      actorSystem = sc.env.actorSystem
-      runAMActor(
+      rpcEnv = sc.env.rpcEnv
+      runAMEndpoint(
         sc.getConf.get("spark.driver.host"),
         sc.getConf.get("spark.driver.port"),
         isClusterMode = true)
@@ -283,8 +280,7 @@ private[spark] class ApplicationMaster(
   }
 
   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
-    actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
-      conf = sparkConf, securityManager = securityMgr)._1
+    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr)
     waitForSparkDriver()
     addAmIpFilter()
     registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
@@ -431,7 +427,7 @@ private[spark] class ApplicationMaster(
     sparkConf.set("spark.driver.host", driverHost)
     sparkConf.set("spark.driver.port", driverPort.toString)
 
-    runAMActor(driverHost, driverPort.toString, isClusterMode = false)
+    runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false)
   }
 
   /** Add the Yarn IP filter that is required for properly securing the UI. */
@@ -443,7 +439,7 @@ private[spark] class ApplicationMaster(
       System.setProperty("spark.ui.filters", amFilter)
       params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) }
     } else {
-      actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase)
+      amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase))
     }
   }
 
@@ -505,44 +501,29 @@ private[spark] class ApplicationMaster(
   }
 
   /**
-   * An actor that communicates with the driver's scheduler backend.
+   * An [[RpcEndpoint]] that communicates with the driver's scheduler backend.
    */
-  private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor {
-    var driver: ActorSelection = _
-
-    override def preStart(): Unit = {
-      logInfo("Listen to driver: " + driverUrl)
-      driver = context.actorSelection(driverUrl)
-      // Send a hello message to establish the connection, after which
-      // we can monitor Lifecycle Events.
-      driver ! "Hello"
-      driver ! RegisterClusterManager
-      // In cluster mode, the AM can directly monitor the driver status instead
-      // of trying to deduce it from the lifecycle of the driver's actor
-      if (!isClusterMode) {
-        context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
-      }
+  private class AMEndpoint(
+      override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean)
+    extends RpcEndpoint with Logging {
+
+    override def onStart(): Unit = {
+      driver.send(RegisterClusterManager(self))
     }
 
     override def receive: PartialFunction[Any, Unit] = {
-      case x: DisassociatedEvent =>
-        logInfo(s"Driver terminated or disconnected! Shutting down. $x")
-        // In cluster mode, do not rely on the disassociated event to exit
-        // This avoids potentially reporting incorrect exit codes if the driver fails
-        if (!isClusterMode) {
-          finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
-        }
-
       case x: AddWebUIFilter =>
         logInfo(s"Add WebUI Filter. $x")
-        driver ! x
+        driver.send(x)
+    }
 
+    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
       case RequestExecutors(requestedTotal) =>
         Option(allocator) match {
           case Some(a) => a.requestTotalExecutors(requestedTotal)
           case None => logWarning("Container allocator is not ready to request executors yet.")
         }
-        sender ! true
+        context.reply(true)
 
       case KillExecutors(executorIds) =>
         logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.")
@@ -550,7 +531,16 @@ private[spark] class ApplicationMaster(
         Option(allocator) match {
           case Some(a) => executorIds.foreach(a.killExecutor)
           case None => logWarning("Container allocator is not ready to kill executors yet.")
         }
-        sender ! true
+        context.reply(true)
+    }
+
+    override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+      logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress")
+      // In cluster mode, do not rely on the disassociated event to exit
+      // This avoids potentially reporting incorrect exit codes if the driver fails
+      if (!isClusterMode) {
+        finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+      }
+    }
   }
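The AMActor-to-AMEndpoint rewrite above is the template for the whole migration: preStart becomes onStart, one-way `!` sends arrive in receive, ask-style messages move to receiveAndReply, and the DisassociatedEvent subscription becomes the onDisconnected callback. A condensed skeleton of that shape (Ping and Query are stand-in message types, not part of this commit):

  import org.apache.spark.rpc._

  case object Ping
  case class Query(n: Int)

  class SketchEndpoint(override val rpcEnv: RpcEnv, driver: RpcEndpointRef)
    extends RpcEndpoint {

    // Counterpart of Actor.preStart; `self` is this endpoint's own ref.
    override def onStart(): Unit = driver.send(Ping)

    // One-way messages (the old `!`) are dispatched here.
    override def receive: PartialFunction[Any, Unit] = {
      case Ping => // fire-and-forget handling
    }

    // Ask-style messages are dispatched here; context.reply replaces `sender ! ...`.
    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
      case Query(n) => context.reply(n + 1)
    }

    // Replaces subscribing to RemotingLifecycleEvent / DisassociatedEvent.
    override def onDisconnected(remoteAddress: RpcAddress): Unit = {
      // react to the remote side going away
    }
  }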
$x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isClusterMode) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } - case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -550,7 +531,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala ---------------------------------------------------------------------- diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e..b8f42da 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,7 +112,7 @@ private[yarn] class YarnAllocator( SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org