Github user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17955#discussion_r116344085
  
    --- Diff: core/src/main/scala/org/apache/spark/MapOutputTracker.scala ---
    @@ -495,106 +532,153 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
         None
       }
     
    -  def incrementEpoch() {
    +  private def incrementEpoch() {
         epochLock.synchronized {
           epoch += 1
           logDebug("Increasing epoch to " + epoch)
         }
       }
     
    -  private def removeBroadcast(bcast: Broadcast[_]): Unit = {
    -    if (null != bcast) {
    -      broadcastManager.unbroadcast(bcast.id,
    -        removeFromDriver = true, blocking = false)
    +  /** Called to get current epoch number. */
    +  def getEpoch: Long = {
    +    epochLock.synchronized {
    +      return epoch
         }
       }
     
    -  private def clearCachedBroadcast(): Unit = {
    -    for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
    -    cachedSerializedBroadcast.clear()
    -  }
    -
    -  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
    -    var statuses: Array[MapStatus] = null
    -    var retBytes: Array[Byte] = null
    -    var epochGotten: Long = -1
    -
    -    // Check to see if we have a cached version, returns true if it does
    -    // and has side effect of setting retBytes.  If not returns false
    -    // with side effect of setting statuses
    -    def checkCachedStatuses(): Boolean = {
    -      epochLock.synchronized {
    -        if (epoch > cacheEpoch) {
    -          cachedSerializedStatuses.clear()
    -          clearCachedBroadcast()
    -          cacheEpoch = epoch
    -        }
    -        cachedSerializedStatuses.get(shuffleId) match {
    -          case Some(bytes) =>
    -            retBytes = bytes
    -            true
    -          case None =>
    -            logDebug("cached status not found for : " + shuffleId)
    -            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
    -            epochGotten = epoch
    -            false
    -        }
    -      }
    -    }
    -
    -    if (checkCachedStatuses()) return retBytes
    -    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
    -    if (null == shuffleIdLock) {
    -      val newLock = new Object()
    -      // in general, this condition should be false - but good to be paranoid
    -      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
    -      shuffleIdLock = if (null != prevLock) prevLock else newLock
    -    }
    -    // synchronize so we only serialize/broadcast it once since multiple threads call
    -    // in parallel
    -    shuffleIdLock.synchronized {
    -      // double check to make sure someone else didn't serialize and cache the same
    -      // mapstatus while we were waiting on the synchronize
    -      if (checkCachedStatuses()) return retBytes
    -
    -      // If we got here, we failed to find the serialized locations in the cache, so we pulled
    -      // out a snapshot of the locations as "statuses"; let's serialize and return that
    -      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
    -        isLocal, minSizeForBroadcast)
    -      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
    -      // Add them into the table only if the epoch hasn't changed while we were working
    -      epochLock.synchronized {
    -        if (epoch == epochGotten) {
    -          cachedSerializedStatuses(shuffleId) = bytes
    -          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
    -        } else {
    -          logInfo("Epoch changed, not caching!")
    -          removeBroadcast(bcast)
    +  // This method is only called in local-mode.
    +  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    +      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    +    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    +    shuffleStatuses.get(shuffleId) match {
    +      case Some (shuffleStatus) =>
    +        shuffleStatus.withMapStatuses { statuses =>
    +          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
             }
    -      }
    -      bytes
    +      case None =>
    +        Seq.empty
         }
       }
     
       override def stop() {
         mapOutputRequests.offer(PoisonPill)
         threadpool.shutdown()
         sendTracker(StopMapOutputTracker)
    -    mapStatuses.clear()
         trackerEndpoint = null
    -    cachedSerializedStatuses.clear()
    -    clearCachedBroadcast()
    -    shuffleIdLocks.clear()
    +    shuffleStatuses.clear()
       }
     }
     
     /**
    - * MapOutputTracker for the executors, which fetches map output information from the driver's
    - * MapOutputTrackerMaster.
    + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
    + * Note that this is not used in local-mode; instead, local-mode Executors access the
    + * MapOutputTrackerMaster directly (which is possible because the master and worker share a common
    + * superclass).
      */
     private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
    -  protected val mapStatuses: Map[Int, Array[MapStatus]] =
    +
    +  val mapStatuses: Map[Int, Array[MapStatus]] =
         new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
    +
    +  /** Remembers which map output locations are currently being fetched on an executor. */
    +  private val fetching = new HashSet[Int]
    +
    +  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    +      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    +    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    +    val statuses = getStatuses(shuffleId)
    +    try {
    +      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    +    } catch {
    +      case e: MetadataFetchFailedException =>
    +        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
    +        mapStatuses.clear()
    --- End diff --
    
    The idea here is to _locally_ clear the mapStatuses cache. In the old code, the cache was only
    cleared indirectly: after the DAGScheduler handled the FetchFailure, the MapOutputTrackerMaster
    incremented its epoch, and the executor then dropped its cached statuses.
    
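    The cache is still cleared when the master sends us a higher epoch. However, the driver-side
    epoch is now bumped only after outputs are actually lost (or missing outputs become available),
    not after every fetch failure. A fetch failure that does not change the driver's epoch would
    therefore never trigger that epoch-based clearing. Without this local clear, the executor could
    keep reusing its stale map statuses.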


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it, or if the feature is enabled but not working,
please contact infrastructure at infrastruct...@apache.org or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to