mridulm commented on code in PR #36162:
URL: https://github.com/apache/spark/pull/36162#discussion_r895335273


##########
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala:
##########
@@ -1218,6 +1249,71 @@ private[spark] class TaskSetManager(
   def executorAdded(): Unit = {
     recomputeLocality()
   }
+
+  /**
+   * A class for checking whether tasks to be speculated are inefficient. The candidate
+   * tasks come from the tasks that would be speculated by the existing duration-based strategy.
+   */
+  private[TaskSetManager] class TaskProcessRateCalculator {
+    private var totalRecordsRead = 0L
+    private var totalExecutorRunTime = 0L
+    private var avgTaskProcessRate = 0.0D
+    private val runningTasksProcessRate = new ConcurrentHashMap[Long, Double]()
+
+    private[TaskSetManager] def updateAvgTaskProcessRate(
+        taskId: Long,
+        result: DirectTaskResult[_]): Unit = {
+      var recordsRead = 0L
+      var executorRunTime = 0L
+      result.accumUpdates.foreach { a =>
+        if (a.name == Some(shuffleRead.RECORDS_READ) ||
+          a.name == Some(input.RECORDS_READ)) {
+          val acc = a.asInstanceOf[LongAccumulator]
+          recordsRead += acc.value
+        } else if (a.name == Some(InternalAccumulator.EXECUTOR_RUN_TIME)) {
+          val acc = a.asInstanceOf[LongAccumulator]
+          executorRunTime = acc.value
+        }
+      }
+      totalRecordsRead += recordsRead
+      totalExecutorRunTime += executorRunTime
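+      // Average process rate across all successfully finished tasks, in records per second.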
+      if (totalRecordsRead > 0 && totalExecutorRunTime > 0) {
+        avgTaskProcessRate = totalRecordsRead / (totalExecutorRunTime / 1000.0)
+      }
+      runningTasksProcessRate.remove(taskId)
+    }
+
+    private[scheduler] def updateRunningTaskProcessRate(
+        taskId: Long,
+        taskProcessRate: Double): Unit = {
+      runningTasksProcessRate.put(taskId, taskProcessRate)
+    }
+
+    private[TaskSetManager] def isEfficient(

Review Comment:
   This should be `isInefficient`, right?
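
   For reference, a hedged sketch of how such an inefficiency check might look. This is an illustration only: the method body and the multiplier `inefficiencyRateMultiplier` are assumptions, not necessarily what this PR implements; the fields `runningTasksProcessRate` and `avgTaskProcessRate` are the ones defined above.

   ```scala
   // Hypothetical sketch, not the PR's actual implementation.
   // A running task is flagged as inefficient when its observed process rate
   // (records per second) drops below a configured fraction of the average
   // rate of the tasks that already finished successfully in this TaskSetManager.
   private[TaskSetManager] def isInefficient(taskId: Long): Boolean = {
     val rate = runningTasksProcessRate.getOrDefault(taskId, 0.0)
     // Only judge a task once both the average and its own rate are known.
     avgTaskProcessRate > 0.0 && rate > 0.0 &&
       rate < avgTaskProcessRate * inefficiencyRateMultiplier // assumed config value
   }
   ```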



##########
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala:
##########
@@ -1109,40 +1155,25 @@ private[spark] class TaskSetManager(
     // `successfulTaskDurations` may not equal to `tasksSuccessful`. Here we should only count the
     // tasks that are submitted by this `TaskSetManager` and are completed successfully.
     val numSuccessfulTasks = successfulTaskDurations.size()
+    val timeMs = clock.getTimeMillis()
     if (numSuccessfulTasks >= minFinishedForSpeculation) {
-      val time = clock.getTimeMillis()
       val medianDuration = successfulTaskDurations.median
       val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation)
       // TODO: Threshold should also look at standard deviation of task durations and have a lower
       // bound based on that.
       logDebug("Task length threshold for speculation: " + threshold)
-      for (tid <- runningTasksSet) {
-        var speculated = checkAndSubmitSpeculatableTask(tid, time, threshold)
-        if (!speculated && executorDecommissionKillInterval.isDefined) {
-          val taskInfo = taskInfos(tid)
-          val decomState = sched.getExecutorDecommissionState(taskInfo.executorId)
-          if (decomState.isDefined) {
-            // Check if this task might finish after this executor is decommissioned.
-            // We estimate the task's finish time by using the median task duration.
-            // Whereas the time when the executor might be decommissioned is estimated using the
-            // config executorDecommissionKillInterval. If the task is going to finish after
-            // decommissioning, then we will eagerly speculate the task.
-            val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
-            val executorDecomTime = decomState.get.startTime + executorDecommissionKillInterval.get
-            val canExceedDeadline = executorDecomTime < taskEndTimeBasedOnMedianDuration
-            if (canExceedDeadline) {
-              speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
-            }
-          }
-        }
-        foundTasks |= speculated
-      }
-    } else if (speculationTaskDurationThresOpt.isDefined && speculationTasksLessEqToSlots) {
-      val time = clock.getTimeMillis()
-      val threshold = speculationTaskDurationThresOpt.get
+      foundTasks = checkAndSubmitSpeculatableTasks(timeMs, threshold, numSuccessfulTasks)
+    } else if (isSpeculationThresholdSpecified && speculationTasksLessEqToSlots) {
+      val threshold = max(speculationTaskDurationThresOpt.get, minTimeToSpeculation)

Review Comment:
   I seem to be forgetting, why did we change from `speculationTaskDurationThresOpt.get` to `max(speculationTaskDurationThresOpt.get, minTimeToSpeculation)` here?
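
   To make the question concrete, here's a tiny sketch with made-up numbers (neither value is a Spark default) showing what the `max(...)` floor changes:

   ```scala
   // Made-up values, only to illustrate the effect of flooring the threshold.
   val speculationTaskDurationThres = 100L // ms: hypothetical user-configured threshold
   val minTimeToSpeculation = 300L         // ms: hypothetical lower bound

   val oldThreshold = speculationTaskDurationThres                                 // 100 ms
   val newThreshold = math.max(speculationTaskDurationThres, minTimeToSpeculation) // 300 ms
   // With the floor, a task must run at least 300 ms before it becomes a
   // speculation candidate, even though the configured threshold is lower.
   ```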



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

