Repository: spark
Updated Branches:
  refs/heads/master f17d43b03 -> 0745a305f


Tighten up field/method visibility in Executor and make some code clearer to read.

I was reading Executor just now and found that some recent changes introduced a
rather convoluted code path with too much monadic chaining and unnecessary fields.
I cleaned it up a bit and tightened the visibility of various fields/methods. I
also added some inline documentation to make this code easier to understand.
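
To make the "monadic chaining" point concrete: TaskRunner previously kept a separate
attemptedTask: Option[Task[Any]] field next to the task field and chained through it to
reach the metrics; the patch below instead derives the Option directly from the (possibly
null) task reference. A minimal, self-contained sketch of that pattern, using hypothetical
stand-in types rather than Spark's own (see the Executor.scala hunks below for the real
change):

    object OptionCleanupSketch {
      final case class Metrics(var runTimeMs: Long = 0L)
      final class Task(val metrics: Option[Metrics])

      // Set later by another thread; may still be null when a failure is reported.
      @volatile private var task: Task = _

      /** After-style: derive Option[Metrics] from the nullable field in one expression. */
      def metricsOnFailure(taskStartMs: Long): Option[Metrics] =
        Option(task).flatMap { t =>
          t.metrics.map { m =>
            m.runTimeMs = System.currentTimeMillis() - taskStartMs
            m
          }
        }

      def main(args: Array[String]): Unit = {
        task = new Task(Some(Metrics()))
        println(metricsOnFailure(System.currentTimeMillis()))  // Some(Metrics(0)), or close to it
      }
    }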

Author: Reynold Xin <r...@databricks.com>

Closes #4850 from rxin/executor and squashes the following commits:

866fc60 [Reynold Xin] Code review feedback.
020efbb [Reynold Xin] Tighten up field/method visibility in Executor and make some code clearer to read.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0745a305
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0745a305
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0745a305

Branch: refs/heads/master
Commit: 0745a305fac622a6eeb8aa4a7401205a14252939
Parents: f17d43b
Author: Reynold Xin <r...@databricks.com>
Authored: Thu Mar 19 22:12:01 2015 -0400
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Mar 19 22:12:01 2015 -0400

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskEndReason.scala  |   6 +-
 .../spark/executor/CommitDeniedException.scala  |   6 +-
 .../org/apache/spark/executor/Executor.scala    | 196 ++++++++++---------
 .../apache/spark/executor/ExecutorSource.scala  |  16 +-
 .../scala/org/apache/spark/scheduler/Task.scala |   2 +-
 5 files changed, 120 insertions(+), 106 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0745a305/core/src/main/scala/org/apache/spark/TaskEndReason.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 29a5cd5..48fd3e7 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -151,11 +151,7 @@ case object TaskKilled extends TaskFailedReason {
  * Task requested the driver to commit, but was denied.
  */
 @DeveloperApi
-case class TaskCommitDenied(
-    jobID: Int,
-    partitionID: Int,
-    attemptID: Int)
-  extends TaskFailedReason {
+case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
   override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
     s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0745a305/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
index f7604a3..f47d7ef 100644
--- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -22,14 +22,12 @@ import org.apache.spark.{TaskCommitDenied, TaskEndReason}
 /**
  * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
  */
-class CommitDeniedException(
+private[spark] class CommitDeniedException(
     msg: String,
     jobID: Int,
     splitID: Int,
     attemptID: Int)
   extends Exception(msg) {
 
-  def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
-
+  def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID)
 }
-

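Taken together, the two hunks above mean a denied output commit now surfaces as the
single-line TaskCommitDenied case class, built via its apply method rather than `new`,
with the exception's splitID mapped to the reason's partitionID. A quick sketch of the
conversion (hypothetical usage only; since CommitDeniedException is now private[spark],
this would have to live inside Spark's own org.apache.spark package tree):

    package org.apache.spark

    import org.apache.spark.executor.CommitDeniedException

    object CommitDeniedSketch {
      def main(args: Array[String]): Unit = {
        val e = new CommitDeniedException("denied", jobID = 1, splitID = 2, attemptID = 0)
        e.toTaskEndReason match {
          case denied: TaskCommitDenied =>
            // TaskCommitDenied (Driver denied task commit) for job: 1, partition: 2, attempt: 0
            println(denied.toErrorString)
          case other =>
            println(s"unexpected reason: $other")
        }
      }
    }
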
http://git-wip-us.apache.org/repos/asf/spark/blob/0745a305/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 6196f7b..bf3135e 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.net.URL
 import java.nio.ByteBuffer
-import java.util.concurrent._
+import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -31,15 +31,17 @@ import akka.actor.Props
 
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
-  SparkUncaughtExceptionHandler, AkkaUtils, Utils}
+import org.apache.spark.util._
 
 /**
- * Spark executor used with Mesos, YARN, and the standalone scheduler.
- * In coarse-grained mode, an existing actor system is provided.
+ * Spark executor, backed by a threadpool to run tasks.
+ *
+ * This can be used with Mesos, YARN, and the standalone scheduler.
+ * An internal RPC interface (at the moment Akka) is used for communication with the driver,
+ * except in the case of Mesos fine-grained mode.
  */
 private[spark] class Executor(
     executorId: String,
@@ -47,8 +49,8 @@ private[spark] class Executor(
     env: SparkEnv,
     userClassPath: Seq[URL] = Nil,
     isLocal: Boolean = false)
-  extends Logging
-{
+  extends Logging {
+
   logInfo(s"Starting executor ID $executorId on host $executorHostname")
 
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -78,9 +80,8 @@ private[spark] class Executor(
   }
 
   // Start worker thread pool
-  val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
-
-  val executorSource = new ExecutorSource(this, executorId)
+  private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+  private val executorSource = new ExecutorSource(threadPool, executorId)
 
   if (!isLocal) {
     env.metricsSystem.registerSource(executorSource)
@@ -122,21 +123,21 @@ private[spark] class Executor(
       taskId: Long,
       attemptNumber: Int,
       taskName: String,
-      serializedTask: ByteBuffer) {
+      serializedTask: ByteBuffer): Unit = {
     val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
       serializedTask)
     runningTasks.put(taskId, tr)
     threadPool.execute(tr)
   }
 
-  def killTask(taskId: Long, interruptThread: Boolean) {
+  def killTask(taskId: Long, interruptThread: Boolean): Unit = {
     val tr = runningTasks.get(taskId)
     if (tr != null) {
       tr.kill(interruptThread)
     }
   }
 
-  def stop() {
+  def stop(): Unit = {
     env.metricsSystem.report()
     env.actorSystem.stop(executorActor)
     isStopped = true
@@ -146,7 +147,10 @@ private[spark] class Executor(
     }
   }
 
-  private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+  /** Returns the total amount of time this JVM process has spent in garbage collection. */
+  private def computeTotalGcTime(): Long = {
+    ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+  }
 
   class TaskRunner(
       execBackend: ExecutorBackend,
@@ -156,12 +160,19 @@ private[spark] class Executor(
       serializedTask: ByteBuffer)
     extends Runnable {
 
+    /** Whether this task has been killed. */
     @volatile private var killed = false
-    @volatile var task: Task[Any] = _
-    @volatile var attemptedTask: Option[Task[Any]] = None
+
+    /** How much the JVM process has spent in GC when the task starts to run. */
     @volatile var startGCTime: Long = _
 
-    def kill(interruptThread: Boolean) {
+    /**
+     * The task to run. This will be set in run() by deserializing the task binary coming
+     * from the driver. Once it is set, it will never be changed.
+     */
+    @volatile var task: Task[Any] = _
+
+    def kill(interruptThread: Boolean): Unit = {
       logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
       killed = true
       if (task != null) {
@@ -169,14 +180,14 @@ private[spark] class Executor(
       }
     }
 
-    override def run() {
+    override def run(): Unit = {
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = env.closureSerializer.newInstance()
       logInfo(s"Running $taskName (TID $taskId)")
       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
       var taskStart: Long = 0
-      startGCTime = gcTime
+      startGCTime = computeTotalGcTime()
 
       try {
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
@@ -193,7 +204,6 @@ private[spark] class Executor(
           throw new TaskKilledException
         }
 
-        attemptedTask = Some(task)
         logDebug("Task " + taskId + "'s epoch is " + task.epoch)
         env.mapOutputTracker.updateEpoch(task.epoch)
 
@@ -215,18 +225,17 @@ private[spark] class Executor(
         for (m <- task.metrics) {
           m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
           m.setExecutorRunTime(taskFinish - taskStart)
-          m.setJvmGCTime(gcTime - startGCTime)
+          m.setJvmGCTime(computeTotalGcTime() - startGCTime)
           m.setResultSerializationTime(afterSerialization - beforeSerialization)
         }
 
         val accumUpdates = Accumulators.values
-
         val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
         val serializedDirectResult = ser.serialize(directResult)
         val resultSize = serializedDirectResult.limit
 
         // directSend = sending directly back to the driver
-        val serializedResult = {
+        val serializedResult: ByteBuffer = {
           if (maxResultSize > 0 && resultSize > maxResultSize) {
             logWarning(s"Finished $taskName (TID $taskId). Result is larger 
than maxResultSize " +
               s"(${Utils.bytesToString(resultSize)} > 
${Utils.bytesToString(maxResultSize)}), " +
@@ -248,42 +257,40 @@ private[spark] class Executor(
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
       } catch {
-        case ffe: FetchFailedException => {
+        case ffe: FetchFailedException =>
           val reason = ffe.toTaskEndReason
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
-        }
 
-        case _: TaskKilledException | _: InterruptedException if task.killed => {
+        case _: TaskKilledException | _: InterruptedException if task.killed =>
           logInfo(s"Executor killed $taskName (TID $taskId)")
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
-        }
 
-        case cDE: CommitDeniedException => {
+        case cDE: CommitDeniedException =>
           val reason = cDE.toTaskEndReason
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
-        }
 
-        case t: Throwable => {
+        case t: Throwable =>
           // Attempt to exit cleanly by informing the driver of our failure.
           // If anything goes wrong (or this was a fatal exception), we will delegate to
           // the default uncaught exception handler, which will terminate the Executor.
           logError(s"Exception in $taskName (TID $taskId)", t)
 
-          val serviceTime = System.currentTimeMillis() - taskStart
-          val metrics = attemptedTask.flatMap(t => t.metrics)
-          for (m <- metrics) {
-            m.setExecutorRunTime(serviceTime)
-            m.setJvmGCTime(gcTime - startGCTime)
+          val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
+            task.metrics.map { m =>
+              m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
+              m.setJvmGCTime(computeTotalGcTime() - startGCTime)
+              m
+            }
           }
-          val reason = new ExceptionFailure(t, metrics)
-          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+          val taskEndReason = new ExceptionFailure(t, metrics)
+          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
 
           // Don't forcibly exit unless the exception was inherently fatal, to avoid
           // stopping other tasks unnecessarily.
           if (Utils.isFatalError(t)) {
             SparkUncaughtExceptionHandler.uncaughtException(t)
           }
-        }
+
       } finally {
         // Release memory used by this thread for shuffles
         env.shuffleMemoryManager.releaseMemoryForThisThread()
@@ -358,7 +365,7 @@ private[spark] class Executor(
       for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
         logInfo("Fetching " + name + " with timestamp " + timestamp)
         // Fetch file with useCache mode, close cache for local mode.
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
           env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
         currentFiles(name) = timestamp
       }
@@ -370,12 +377,12 @@ private[spark] class Executor(
         if (currentTimeStamp < timestamp) {
           logInfo("Fetching " + name + " with timestamp " + timestamp)
           // Fetch file with useCache mode, close cache for local mode.
-          Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+          Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
             env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
           currentJars(name) = timestamp
           // Add it to our class loader
-          val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
-          if (!urlClassLoader.getURLs.contains(url)) {
+          val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
+          if (!urlClassLoader.getURLs().contains(url)) {
             logInfo("Adding " + url + " to class loader")
             urlClassLoader.addURL(url)
           }
@@ -384,61 +391,70 @@ private[spark] class Executor(
     }
   }
 
-  def startDriverHeartbeater() {
-    val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
-    val timeout = AkkaUtils.lookupTimeout(conf)
-    val retryAttempts = AkkaUtils.numRetries(conf)
-    val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
-    val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+  private val timeout = AkkaUtils.lookupTimeout(conf)
+  private val retryAttempts = AkkaUtils.numRetries(conf)
+  private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
+  private val heartbeatReceiverRef =
+    AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+
+  /** Reports heartbeat and metrics for active tasks to the driver. */
+  private def reportHeartBeat(): Unit = {
+    // list of (task id, metrics) to send back to the driver
+    val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+    val curGCTime = computeTotalGcTime()
+
+    for (taskRunner <- runningTasks.values()) {
+      if (taskRunner.task != null) {
+        taskRunner.task.metrics.foreach { metrics =>
+          metrics.updateShuffleReadMetrics()
+          metrics.updateInputMetrics()
+          metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+
+          if (isLocal) {
+            // JobProgressListener will hold an reference of it during
+            // onExecutorMetricsUpdate(), then JobProgressListener can not see
+            // the changes of metrics any more, so make a deep copy of it
+            val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
+            tasksMetrics += ((taskRunner.taskId, copiedMetrics))
+          } else {
+            // It will be copied by serialization
+            tasksMetrics += ((taskRunner.taskId, metrics))
+          }
+        }
+      }
+    }
 
-    val t = new Thread() {
+    val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+    try {
+      val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
+        retryAttempts, retryIntervalMs, timeout)
+      if (response.reregisterBlockManager) {
+        logWarning("Told to re-register on heartbeat")
+        env.blockManager.reregister()
+      }
+    } catch {
+      case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
+    }
+  }
+
+  /**
+   * Starts a thread to report heartbeat and partial metrics for active tasks to driver.
+   * This thread stops running when the executor is stopped.
+   */
+  private def startDriverHeartbeater(): Unit = {
+    val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
+    val thread = new Thread() {
       override def run() {
         // Sleep a random interval so the heartbeats don't end up in sync
         Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])
-
         while (!isStopped) {
-          val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
-          val curGCTime = gcTime
-
-          for (taskRunner <- runningTasks.values()) {
-            if (taskRunner.attemptedTask.nonEmpty) {
-              Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
-                metrics.updateShuffleReadMetrics()
-                metrics.updateInputMetrics()
-                metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-
-                if (isLocal) {
-                  // JobProgressListener will hold an reference of it during
-                  // onExecutorMetricsUpdate(), then JobProgressListener can not see
-                  // the changes of metrics any more, so make a deep copy of it
-                  val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
-                  tasksMetrics += ((taskRunner.taskId, copiedMetrics))
-                } else {
-                  // It will be copied by serialization
-                  tasksMetrics += ((taskRunner.taskId, metrics))
-                }
-              }
-            }
-          }
-
-          val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
-          try {
-            val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
-              retryAttempts, retryIntervalMs, timeout)
-            if (response.reregisterBlockManager) {
-              logWarning("Told to re-register on heartbeat")
-              env.blockManager.reregister()
-            }
-          } catch {
-            case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t)
-          }
-
+          reportHeartBeat()
           Thread.sleep(interval)
         }
       }
     }
-    t.setDaemon(true)
-    t.setName("Driver Heartbeater")
-    t.start()
+    thread.setDaemon(true)
+    thread.setName("driver-heartbeater")
+    thread.start()
   }
 }

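One practical detail worth noting from the heartbeater hunks above: the interval is read
from spark.executor.heartbeatInterval via conf.getInt with a default of 10000, i.e. a plain
number of milliseconds in this version, and the first beat is delayed by a random extra
interval so executors do not heartbeat in lockstep. A hedged sketch of overriding it (the
5000 ms value is an arbitrary choice for illustration):

    import org.apache.spark.SparkConf

    // Ask executors to heartbeat every 5 seconds instead of the 10000 ms default
    // read by startDriverHeartbeater() above.
    val conf = new SparkConf()
      .setAppName("heartbeat-interval-example")
      .set("spark.executor.heartbeatInterval", "5000")
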
http://git-wip-us.apache.org/repos/asf/spark/blob/0745a305/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index c4d7362..293c512 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.executor
 
+import java.util.concurrent.ThreadPoolExecutor
+
 import scala.collection.JavaConversions._
 
 import com.codahale.metrics.{Gauge, MetricRegistry}
@@ -24,9 +26,11 @@ import org.apache.hadoop.fs.FileSystem
 
 import org.apache.spark.metrics.source.Source
 
-private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source {
+private[spark]
+class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source {
+
   private def fileStats(scheme: String) : Option[FileSystem.Statistics] =
-    FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption
+    FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme))
 
   private def registerFileSystemStat[T](
         scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = {
@@ -41,23 +45,23 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String)
 
   // Gauge for executor thread pool's actively executing task counts
   metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), 
new Gauge[Int] {
-    override def getValue: Int = executor.threadPool.getActiveCount()
+    override def getValue: Int = threadPool.getActiveCount()
   })
 
   // Gauge for executor thread pool's approximate total number of tasks that have been completed
   metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] {
-    override def getValue: Long = executor.threadPool.getCompletedTaskCount()
+    override def getValue: Long = threadPool.getCompletedTaskCount()
   })
 
   // Gauge for executor thread pool's current number of threads
   metricRegistry.register(MetricRegistry.name("threadpool", 
"currentPool_size"), new Gauge[Int] {
-    override def getValue: Int = executor.threadPool.getPoolSize()
+    override def getValue: Int = threadPool.getPoolSize()
   })
 
   // Gauge got executor thread pool's largest number of threads that have ever simultaneously
   // been in th pool
   metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), 
new Gauge[Int] {
-    override def getValue: Int = executor.threadPool.getMaximumPoolSize()
+    override def getValue: Int = threadPool.getMaximumPoolSize()
   })
 
   // Gauge for file system stats of this executor

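The constructor change above means the metrics source now depends only on a
ThreadPoolExecutor rather than on the whole Executor, so the gauges can be exercised
without a running executor. A minimal sketch of such usage (hypothetical, and placed in
the org.apache.spark.executor package because both ExecutorSource and Utils are
private[spark]):

    package org.apache.spark.executor

    import org.apache.spark.util.Utils

    object ExecutorSourceSketch {
      def main(args: Array[String]): Unit = {
        // Any ThreadPoolExecutor will do; no Executor instance is required any more.
        val pool = Utils.newDaemonCachedThreadPool("sketch-pool")
        val source = new ExecutorSource(pool, "exec-0")
        // The thread-pool gauges now read straight from the pool itself.
        println(source.metricRegistry.getGauges.size())
        pool.shutdown()
      }
    }
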
http://git-wip-us.apache.org/repos/asf/spark/blob/0745a305/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 847a491..4d9f940 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
 
   /**
-   * Called by Executor to run this task.
+   * Called by [[Executor]] to run this task.
    *
    * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
    * @param attemptNumber how many times this task has been attempted (0 for the first attempt)

