seojangho closed pull request #52: [NEMO-122] Manage Task/Stage/Job states in one place
URL: https://github.com/apache/incubator-nemo/pull/52
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java b/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java index 523463ff..fc33eeae 100644 --- a/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java +++ b/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java @@ -54,7 +54,7 @@ public DriverEndpoint(final JobStateManager jobStateManager, * @return the current state of the running job. */ JobState.State getState() { - return (JobState.State) jobStateManager.getJobState().getStateMachine().getCurrentState(); + return jobStateManager.getJobState(); } /** @@ -67,7 +67,7 @@ public DriverEndpoint(final JobStateManager jobStateManager, */ JobState.State waitUntilFinish(final long timeout, final TimeUnit unit) { - return (JobState.State) jobStateManager.waitUntilFinish(timeout, unit).getStateMachine().getCurrentState(); + return jobStateManager.waitUntilFinish(timeout, unit); } /** @@ -76,6 +76,6 @@ public DriverEndpoint(final JobStateManager jobStateManager, * @return the final state of this job. 
*/ JobState.State waitUntilFinish() { - return (JobState.State) jobStateManager.waitUntilFinish().getStateMachine().getCurrentState(); + return jobStateManager.waitUntilFinish(); } } diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java index 81a09801..6bf33805 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java @@ -35,18 +35,15 @@ private StateMachine buildTaskStateMachine() { stateMachineBuilder.addState(State.EXECUTING, "The stage is executing."); stateMachineBuilder.addState(State.COMPLETE, "All of this stage's tasks have completed."); stateMachineBuilder.addState(State.FAILED_RECOVERABLE, "Stage failed, but is recoverable."); - stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE, "Stage failed, and is unrecoverable. The job will fail."); // Add transitions stateMachineBuilder.addTransition(State.READY, State.EXECUTING, "The stage can now schedule its tasks"); - stateMachineBuilder.addTransition(State.READY, State.FAILED_UNRECOVERABLE, - "Job Failure"); + stateMachineBuilder.addTransition(State.READY, State.FAILED_RECOVERABLE, + "Recoverable failure"); stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE, "All tasks complete"); - stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE, - "Unrecoverable failure in a task"); stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE, "Recoverable failure in a task"); @@ -55,8 +52,8 @@ private StateMachine buildTaskStateMachine() { stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY, "Recoverable stage failure"); - stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.FAILED_UNRECOVERABLE, - ""); + stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.EXECUTING, + "Recoverable 
stage failure"); stateMachineBuilder.setInitialState(State.READY); @@ -75,7 +72,6 @@ public StateMachine getStateMachine() { EXECUTING, COMPLETE, FAILED_RECOVERABLE, - FAILED_UNRECOVERABLE } @Override diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java index e201e323..b47696af 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java @@ -35,37 +35,35 @@ private StateMachine buildTaskStateMachine() { stateMachineBuilder.addState(State.EXECUTING, "The task is executing."); stateMachineBuilder.addState(State.COMPLETE, "The task has completed."); stateMachineBuilder.addState(State.FAILED_RECOVERABLE, "Task failed, but is recoverable."); - stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE, - "Task failed, and is unrecoverable. The job will fail."); + stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE, "Task failed, and is unrecoverable. 
The job will fail."); stateMachineBuilder.addState(State.ON_HOLD, "The task is paused for dynamic optimization."); - // Add transitions - stateMachineBuilder.addTransition(State.READY, State.EXECUTING, - "Scheduling to executor"); + // From NOT_AVAILABLE + stateMachineBuilder.addTransition(State.READY, State.EXECUTING, "Scheduling to executor"); stateMachineBuilder.addTransition(State.READY, State.FAILED_RECOVERABLE, "Stage Failure by a recoverable failure in another task"); - stateMachineBuilder.addTransition(State.READY, State.FAILED_UNRECOVERABLE, - "Stage Failure"); - - stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE, - "All tasks complete"); - stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE, - "Unrecoverable failure in a task/Executor failure"); - stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE, - "Recoverable failure in a task/Container failure"); + + // From EXECUTING + stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE, "Task completed normally"); + stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE, "Unrecoverable failure"); + stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE, "Recoverable failure"); stateMachineBuilder.addTransition(State.EXECUTING, State.ON_HOLD, "Task paused for dynamic optimization"); - stateMachineBuilder.addTransition(State.ON_HOLD, State.COMPLETE, "Task completed after dynamic optimization"); + // From ON HOLD + stateMachineBuilder.addTransition(State.ON_HOLD, State.COMPLETE, "Task completed after being on hold"); + stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED_UNRECOVERABLE, "Unrecoverable failure"); + stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED_RECOVERABLE, "Recoverable failure"); + + // From COMPLETE + stateMachineBuilder.addTransition(State.COMPLETE, State.EXECUTING, "Completed before, but re-execute"); stateMachineBuilder.addTransition(State.COMPLETE, 
State.FAILED_RECOVERABLE, "Recoverable failure in a task/Container failure"); - stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY, - "Recovered from failure and is ready"); - stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.FAILED_UNRECOVERABLE, - ""); - stateMachineBuilder.setInitialState(State.READY); + // From FAILED_RECOVERABLE + stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY, "Recovered from failure and is ready"); + stateMachineBuilder.setInitialState(State.READY); return stateMachineBuilder.build(); } diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java index 0928caf4..220e5854 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java @@ -43,22 +43,24 @@ import org.apache.reef.annotations.audience.DriverSide; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.stream.Collectors; + +import javax.annotation.concurrent.ThreadSafe; import java.util.stream.IntStream; import static edu.snu.nemo.common.dag.DAG.EMPTY_DAG_DIRECTORY; /** - * Manages the states related to a job. - * This class can be used to track a job's execution status to task level in the future. - * The methods of this class are synchronized. + * Maintains three levels of state machines (JobState, StageState, and TaskState) of a physical plan. + * The main API this class provides is onTaskStateReportFromExecutor(), which directly changes a TaskState. + * JobState and StageState are updated internally in the class, and can only be read from the outside. + * + * (CONCURRENCY) The public methods of this class are synchronized. 
*/ @DriverSide +@ThreadSafe public final class JobStateManager { private static final Logger LOG = LoggerFactory.getLogger(JobStateManager.class.getName()); - private final String jobId; - private final int maxScheduleAttempt; /** @@ -74,38 +76,20 @@ */ private final Map<String, Integer> taskIdToCurrentAttempt; - /** - * Keeps track of the number of schedule attempts for each stage. - */ - private final Map<String, Integer> scheduleAttemptIdxByStage; - /** * Represents the job to manage. */ private final PhysicalPlan physicalPlan; - /** - * Used to track stage completion status. - * All task ids are added to the set when the a stage begins executing. - * Each task id is removed upon completion, - * therefore indicating the stage's completion when this set becomes empty. - */ - private final Map<String, Set<String>> stageIdToRemainingTaskSet; - - /** - * Used to track job completion status. - * All stage ids are added to the set when the this job begins executing. - * Each stage id is removed upon completion, - * therefore indicating the job's completion when this set becomes empty. - */ - private final Set<String> currentJobStageIds; - /** * A lock and condition to check whether the job is finished or not. */ private final Lock finishLock; private final Condition jobFinishedCondition; + /** + * For metrics. 
+ */ private final MetricMessageHandler metricMessageHandler; private final Map<String, MetricDataBuilder> metricDataBuilderMap; @@ -121,9 +105,6 @@ public JobStateManager(final PhysicalPlan physicalPlan, this.idToStageStates = new HashMap<>(); this.idToTaskStates = new HashMap<>(); this.taskIdToCurrentAttempt = new HashMap<>(); - this.scheduleAttemptIdxByStage = new HashMap<>(); - this.stageIdToRemainingTaskSet = new HashMap<>(); - this.currentJobStageIds = new HashSet<>(); this.finishLock = new ReentrantLock(); this.jobFinishedCondition = finishLock.newCondition(); this.metricDataBuilderMap = new HashMap<>(); @@ -139,7 +120,6 @@ private void initializeComputationStates() { // Initialize the states for the job down to task-level. physicalPlan.getStageDAG().topologicalDo(stage -> { - currentJobStageIds.add(stage.getId()); idToStageStates.put(stage.getId(), new StageState()); stage.getTaskIds().forEach(taskId -> { idToTaskStates.put(taskId, new TaskState()); @@ -179,216 +159,176 @@ private void initializePartitionStates(final BlockManagerMaster blockManagerMast } /** - * Updates the state of the job. - * @param newState of the job. + * Updates the state of a task. + * Task state changes can occur both in master and executor. + * State changes that occur in master are + * initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler}. + * State changes that occur in executors are sent to master as a control message, + * and the call to this method is initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler} + * when the message/event is received. + * + * @param taskId the ID of the task. + * @param newTaskState the new state of the task. 
*/ - public synchronized void onJobStateChanged(final JobState.State newState) { + public synchronized void onTaskStateChanged(final String taskId, final TaskState.State newTaskState) { + // Change task state + final StateMachine taskState = idToTaskStates.get(taskId).getStateMachine(); + LOG.debug("Task State Transition: id {}, from {} to {}", + new Object[]{taskId, taskState.getCurrentState(), newTaskState}); + + taskState.setState(newTaskState); + + // Handle metrics final Map<String, Object> metric = new HashMap<>(); + switch (newTaskState) { + case ON_HOLD: + case COMPLETE: + case FAILED_UNRECOVERABLE: + case FAILED_RECOVERABLE: + metric.put("ToState", newTaskState); + endMeasurement(taskId, metric); + break; + case EXECUTING: + metric.put("FromState", newTaskState); + beginMeasurement(taskId, metric); + break; + case READY: + final int currentAttempt = taskIdToCurrentAttempt.get(taskId) + 1; + metric.put("ScheduleAttempt", currentAttempt); + if (currentAttempt <= maxScheduleAttempt) { + taskIdToCurrentAttempt.put(taskId, currentAttempt); + } else { + throw new SchedulingException(new Throwable("Exceeded max number of scheduling attempts for " + taskId)); + } + break; + default: + throw new UnknownExecutionStateException(new Throwable("This task state is unknown")); + } - if (newState == JobState.State.EXECUTING) { - LOG.debug("Executing Job ID {}...", this.jobId); - jobState.getStateMachine().setState(newState); - metric.put("FromState", newState); - beginMeasurement(jobId, metric); - } else if (newState == JobState.State.COMPLETE || newState == JobState.State.FAILED) { - LOG.debug("Job ID {} {}!", new Object[]{jobId, newState}); - // Awake all threads waiting the finish of this job. 
- finishLock.lock(); - try { - jobState.getStateMachine().setState(newState); - metric.put("ToState", newState); - endMeasurement(jobId, metric); + // Change stage state, if needed + final String stageId = RuntimeIdGenerator.getStageIdFromTaskId(taskId); + final List<String> tasksOfThisStage = physicalPlan.getStageDAG().getVertexById(stageId).getTaskIds(); + final long numOfCompletedOrOnHoldTasksInThisStage = tasksOfThisStage + .stream() + .map(this::getTaskState) + .filter(state -> state.equals(TaskState.State.COMPLETE) || state.equals(TaskState.State.ON_HOLD)) + .count(); + switch (newTaskState) { + case READY: + onStageStateChanged(stageId, StageState.State.READY); + break; + case EXECUTING: + onStageStateChanged(stageId, StageState.State.EXECUTING); + break; + case FAILED_RECOVERABLE: + onStageStateChanged(stageId, StageState.State.FAILED_RECOVERABLE); + break; + case COMPLETE: + case ON_HOLD: + if (numOfCompletedOrOnHoldTasksInThisStage == tasksOfThisStage.size()) { + onStageStateChanged(stageId, StageState.State.COMPLETE); + } + break; + case FAILED_UNRECOVERABLE: + break; + default: + throw new UnknownExecutionStateException(new Throwable("This task state is unknown")); + } - jobFinishedCondition.signalAll(); - } finally { - finishLock.unlock(); - } - } else { - throw new IllegalStateTransitionException(new Exception("Illegal Job State Transition")); + // Log not-yet-completed tasks for us to track progress + if (newTaskState.equals(TaskState.State.COMPLETE)) { + LOG.info("{}: {} Task(s) to go", stageId, tasksOfThisStage.size() - numOfCompletedOrOnHoldTasksInThisStage); } } /** + * (PRIVATE METHOD) * Updates the state of a stage. - * Stage state changes only occur in master. * @param stageId of the stage. - * @param newState of the stage. + * @param newStageState of the stage. 
*/ - public synchronized void onStageStateChanged(final String stageId, final StageState.State newState) { + private void onStageStateChanged(final String stageId, final StageState.State newStageState) { + if (newStageState.equals(getStageState(stageId))) { + // Ignore duplicate state updates + return; + } + + // Change stage state final StateMachine stageStateMachine = idToStageStates.get(stageId).getStateMachine(); LOG.debug("Stage State Transition: id {} from {} to {}", - new Object[]{stageId, stageStateMachine.getCurrentState(), newState}); - stageStateMachine.setState(newState); - final Map<String, Object> metric = new HashMap<>(); - - if (newState == StageState.State.EXECUTING) { - if (scheduleAttemptIdxByStage.containsKey(stageId)) { - final int numAttempts = scheduleAttemptIdxByStage.get(stageId); - - if (numAttempts < maxScheduleAttempt) { - scheduleAttemptIdxByStage.put(stageId, numAttempts + 1); - } else { - throw new SchedulingException( - new Throwable("Exceeded max number of scheduling attempts for " + stageId)); - } - } else { - scheduleAttemptIdxByStage.put(stageId, 1); - } + new Object[]{stageId, stageStateMachine.getCurrentState(), newStageState}); + stageStateMachine.setState(newStageState); - metric.put("ScheduleAttempt", scheduleAttemptIdxByStage.get(stageId)); - metric.put("FromState", newState); + // Metric handling + final Map<String, Object> metric = new HashMap<>(); + if (newStageState == StageState.State.EXECUTING) { + metric.put("FromState", newStageState); beginMeasurement(stageId, metric); - - // if there exists a mapping, this state change is from a failed_recoverable stage, - // and there may be tasks that do not need to be re-executed. 
- if (!stageIdToRemainingTaskSet.containsKey(stageId)) { - for (final Stage stage : physicalPlan.getStageDAG().getVertices()) { - if (stage.getId().equals(stageId)) { - Set<String> remainingTaskIds = new HashSet<>(); - remainingTaskIds.addAll( - stage.getTaskIds().stream().collect(Collectors.toSet())); - stageIdToRemainingTaskSet.put(stageId, remainingTaskIds); - break; - } - } - } - } else if (newState == StageState.State.COMPLETE) { - metric.put("ToState", newState); + } else if (newStageState == StageState.State.COMPLETE) { + metric.put("ToState", newStageState); endMeasurement(stageId, metric); + } - currentJobStageIds.remove(stageId); - if (currentJobStageIds.isEmpty()) { - onJobStateChanged(JobState.State.COMPLETE); - } - } else if (newState == StageState.State.FAILED_RECOVERABLE) { - metric.put("ToState", newState); - endMeasurement(stageId, metric); - currentJobStageIds.add(stageId); - } else if (newState == StageState.State.FAILED_UNRECOVERABLE) { - metric.put("ToState", newState); - endMeasurement(stageId, metric); + // Change job state if needed + final boolean allStagesCompleted = idToStageStates.values().stream().allMatch(state -> + state.getStateMachine().getCurrentState().equals(StageState.State.COMPLETE)); + + // (1) Job becomes EXECUTING if not already + if (newStageState.equals(StageState.State.EXECUTING) + && !getJobState().equals(JobState.State.EXECUTING)) { + onJobStateChanged(JobState.State.EXECUTING); + } + // (2) Job becomes COMPLETE + if (allStagesCompleted) { + onJobStateChanged(JobState.State.COMPLETE); } } /** - * Updates the state of a task. - * Task state changes can occur both in master and executor. - * State changes that occur in master are - * initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler}. 
- * State changes that occur in executors are sent to master as a control message, - * and the call to this method is initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler} - * when the message/event is received. - * - * @param taskId the ID of the task. - * @param newState the new state of the task. + * (PRIVATE METHOD) + * Updates the state of the job. + * @param newState of the job. */ - public synchronized void onTaskStateChanged(final String taskId, final TaskState.State newState) { - final StateMachine taskState = idToTaskStates.get(taskId).getStateMachine(); - final String stageId = RuntimeIdGenerator.getStageIdFromTaskId(taskId); - - LOG.debug("Task State Transition: id {}, from {} to {}", - new Object[]{taskId, taskState.getCurrentState(), newState}); - final Map<String, Object> metric = new HashMap<>(); - - switch (newState) { - case ON_HOLD: - case COMPLETE: - taskState.setState(newState); - metric.put("ToState", newState); - endMeasurement(taskId, metric); + private void onJobStateChanged(final JobState.State newState) { + if (newState.equals(getJobState())) { + // Ignore duplicate state updates + return; + } - if (stageIdToRemainingTaskSet.containsKey(stageId)) { - final Set<String> remainingTasks = stageIdToRemainingTaskSet.get(stageId); - LOG.info("{}: {} Task(s) to go", stageId, remainingTasks.size()); - remainingTasks.remove(taskId); + jobState.getStateMachine().setState(newState); - if (remainingTasks.isEmpty()) { - onStageStateChanged(stageId, StageState.State.COMPLETE); - } - } else { - throw new IllegalStateTransitionException( - new Throwable("The stage has not yet been submitted for execution")); - } - break; - case EXECUTING: - taskState.setState(newState); + final Map<String, Object> metric = new HashMap<>(); + if (newState == JobState.State.EXECUTING) { + LOG.debug("Executing Job ID {}...", this.jobId); metric.put("FromState", newState); - beginMeasurement(taskId, metric); - break; - case FAILED_RECOVERABLE: - 
// Multiple calls to set a task's state to failed_recoverable can occur when - // a task is made failed_recoverable early by another task's failure detection in the same stage - // and the task finds itself failed_recoverable later, propagating the state change event only then. - if (taskState.getCurrentState() != TaskState.State.FAILED_RECOVERABLE) { - taskState.setState(newState); - metric.put("ToState", newState); - endMeasurement(taskId, metric); - - // Mark this stage as failed_recoverable as long as it contains at least one failed_recoverable task - if (idToStageStates.get(stageId).getStateMachine().getCurrentState() != StageState.State.FAILED_RECOVERABLE) { - onStageStateChanged(stageId, StageState.State.FAILED_RECOVERABLE); - } - - if (stageIdToRemainingTaskSet.containsKey(stageId)) { - stageIdToRemainingTaskSet.get(stageId).add(taskId); - } else { - throw new IllegalStateTransitionException( - new Throwable("The stage has not yet been submitted for execution")); - } + beginMeasurement(jobId, metric); + } else if (newState == JobState.State.COMPLETE || newState == JobState.State.FAILED) { + LOG.debug("Job ID {} {}!", new Object[]{jobId, newState}); - // We'll recover and retry this task - taskIdToCurrentAttempt.put(taskId, taskIdToCurrentAttempt.get(taskId) + 1); - } else { - LOG.info("{} state is already FAILED_RECOVERABLE. Skipping this event.", - taskId); - } - break; - case READY: - taskState.setState(newState); - break; - case FAILED_UNRECOVERABLE: - taskState.setState(newState); + // Awake all threads waiting the finish of this job. 
+ finishLock.lock(); metric.put("ToState", newState); - endMeasurement(taskId, metric); - break; - default: - throw new UnknownExecutionStateException(new Throwable("This task state is unknown")); - } - } + endMeasurement(jobId, metric); - public synchronized boolean checkStageCompletion(final String stageId) { - return stageIdToRemainingTaskSet.get(stageId).isEmpty(); - } - - public synchronized boolean checkJobTermination() { - final Enum currentState = jobState.getStateMachine().getCurrentState(); - return (currentState == JobState.State.COMPLETE || currentState == JobState.State.FAILED); - } - - public synchronized int getAttemptCountForStage(final String stageId) { - if (scheduleAttemptIdxByStage.containsKey(stageId)) { - return scheduleAttemptIdxByStage.get(stageId); + try { + jobFinishedCondition.signalAll(); + } finally { + finishLock.unlock(); + } } else { - throw new IllegalStateException("No mapping for this stage's attemptIdx, an inconsistent state occurred."); + throw new IllegalStateTransitionException(new Exception("Illegal Job State Transition")); } } - public synchronized int getCurrentAttemptIndexForTask(final String taskId) { - if (taskIdToCurrentAttempt.containsKey(taskId)) { - return taskIdToCurrentAttempt.get(taskId); - } else { - throw new IllegalStateException("No mapping for this task's attemptIdx, an inconsistent state occurred."); - } - } /** * Wait for this job to be finished and return the final state. * @return the final state of this job. */ - public JobState waitUntilFinish() { + public JobState.State waitUntilFinish() { finishLock.lock(); try { - if (!checkJobTermination()) { + if (!isJobDone()) { jobFinishedCondition.await(); } } catch (final InterruptedException e) { @@ -407,11 +347,10 @@ public JobState waitUntilFinish() { * @param unit of the timeout. * @return the final state of this job. 
*/ - public JobState waitUntilFinish(final long timeout, - final TimeUnit unit) { + public JobState.State waitUntilFinish(final long timeout, final TimeUnit unit) { finishLock.lock(); try { - if (!checkJobTermination()) { + if (!isJobDone()) { if (!jobFinishedCondition.await(timeout, unit)) { LOG.warn("Timeout during waiting the finish of Job ID {}", jobId); } @@ -425,28 +364,31 @@ public JobState waitUntilFinish(final long timeout, return getJobState(); } + public synchronized boolean isJobDone() { + return (getJobState() == JobState.State.COMPLETE || getJobState() == JobState.State.FAILED); + } public synchronized String getJobId() { return jobId; } - public synchronized JobState getJobState() { - return jobState; - } - - public synchronized StageState getStageState(final String stageId) { - return idToStageStates.get(stageId); + public synchronized JobState.State getJobState() { + return (JobState.State) jobState.getStateMachine().getCurrentState(); } - public synchronized Map<String, StageState> getIdToStageStates() { - return idToStageStates; + public synchronized StageState.State getStageState(final String stageId) { + return (StageState.State) idToStageStates.get(stageId).getStateMachine().getCurrentState(); } - public synchronized TaskState getTaskState(final String taskId) { - return idToTaskStates.get(taskId); + public synchronized TaskState.State getTaskState(final String taskId) { + return (TaskState.State) idToTaskStates.get(taskId).getStateMachine().getCurrentState(); } - public synchronized Map<String, TaskState> getIdToTaskStates() { - return idToTaskStates; + public synchronized int getTaskAttempt(final String taskId) { + if (taskIdToCurrentAttempt.containsKey(taskId)) { + return taskIdToCurrentAttempt.get(taskId); + } else { + throw new IllegalStateException("No mapping for this task's attemptIdx, an inconsistent state occurred."); + } } /** diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java 
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java index a424eccb..a7df05a0 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java @@ -275,7 +275,7 @@ private void handleControlMessage(final ControlMessage.Message message) { final ControlMessage.TaskStateChangedMsg taskStateChangedMsg = message.getTaskStateChangedMsg(); - scheduler.onTaskStateChanged(taskStateChangedMsg.getExecutorId(), + scheduler.onTaskStateReportFromExecutor(taskStateChangedMsg.getExecutorId(), taskStateChangedMsg.getTaskId(), taskStateChangedMsg.getAttemptIdx(), convertTaskState(taskStateChangedMsg.getState()), diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java index d06e1b5b..3506063c 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java @@ -43,6 +43,7 @@ import org.slf4j.Logger; import static edu.snu.nemo.runtime.common.state.TaskState.State.ON_HOLD; +import static edu.snu.nemo.runtime.common.state.TaskState.State.READY; /** * (WARNING) Only a single dedicated thread should use the public methods of this class. @@ -139,25 +140,25 @@ public void updateJob(final String jobId, final PhysicalPlan newPhysicalPlan, fi * @param vertexPutOnHold the ID of vertex that is put on hold. It is null otherwise. 
*/ @Override - public void onTaskStateChanged(final String executorId, - final String taskId, - final int taskAttemptIndex, - final TaskState.State newState, - @Nullable final String vertexPutOnHold, - final TaskState.RecoverableFailureCause failureCause) { - final int currentTaskAttemptIndex = jobStateManager.getCurrentAttemptIndexForTask(taskId); + public void onTaskStateReportFromExecutor(final String executorId, + final String taskId, + final int taskAttemptIndex, + final TaskState.State newState, + @Nullable final String vertexPutOnHold, + final TaskState.RecoverableFailureCause failureCause) { + final int currentTaskAttemptIndex = jobStateManager.getTaskAttempt(taskId); + if (taskAttemptIndex == currentTaskAttemptIndex) { // Do change state, as this notification is for the current task attempt. + jobStateManager.onTaskStateChanged(taskId, newState); switch (newState) { case COMPLETE: - jobStateManager.onTaskStateChanged(taskId, newState); onTaskExecutionComplete(executorId, taskId); break; case FAILED_RECOVERABLE: - onTaskExecutionFailedRecoverable(executorId, taskId, newState, failureCause); + onTaskExecutionFailedRecoverable(executorId, taskId, failureCause); break; case ON_HOLD: - jobStateManager.onTaskStateChanged(taskId, newState); onTaskExecutionOnHold(executorId, taskId, vertexPutOnHold); break; case FAILED_UNRECOVERABLE: @@ -201,8 +202,8 @@ public void onExecutorRemoved(final String executorId) { }); tasksToReExecute.forEach(failedTaskId -> { - final int attemptIndex = jobStateManager.getCurrentAttemptIndexForTask(failedTaskId); - onTaskStateChanged(executorId, failedTaskId, attemptIndex, TaskState.State.FAILED_RECOVERABLE, + final int attemptIndex = jobStateManager.getTaskAttempt(failedTaskId); + onTaskStateReportFromExecutor(executorId, failedTaskId, attemptIndex, TaskState.State.FAILED_RECOVERABLE, null, TaskState.RecoverableFailureCause.CONTAINER_FAILURE); }); @@ -289,8 +290,7 @@ private void scheduleNextStage(final String completedStageId) { // 
We need to reschedule failed_recoverable stages. for (final Stage stageToCheck : currentScheduleGroup) { - final StageState.State stageState = - (StageState.State) jobStateManager.getStageState(stageToCheck.getId()).getStateMachine().getCurrentState(); + final StageState.State stageState = jobStateManager.getStageState(stageToCheck.getId()); switch (stageState) { case FAILED_RECOVERABLE: stagesToSchedule.add(stageToCheck); @@ -315,10 +315,8 @@ private void scheduleNextStage(final String completedStageId) { physicalPlan.getStageDAG().getTopologicalSort().stream().filter(stage -> { if (stage.getScheduleGroupIndex() == currentScheduleGroupIndex + 1) { final String stageId = stage.getId(); - return jobStateManager.getStageState(stageId).getStateMachine().getCurrentState() - != StageState.State.EXECUTING - && jobStateManager.getStageState(stageId).getStateMachine().getCurrentState() - != StageState.State.COMPLETE; + return jobStateManager.getStageState(stageId) != StageState.State.EXECUTING + && jobStateManager.getStageState(stageId) != StageState.State.COMPLETE; } return false; }).collect(Collectors.toList()); @@ -346,15 +344,13 @@ private void scheduleStage(final Stage stageToSchedule) { final List<StageEdge> stageOutgoingEdges = physicalPlan.getStageDAG().getOutgoingEdgesOf(stageToSchedule.getId()); - final Enum stageState = jobStateManager.getStageState(stageToSchedule.getId()).getStateMachine().getCurrentState(); + final StageState.State stageState = jobStateManager.getStageState(stageToSchedule.getId()); final List<String> taskIdsToSchedule = new LinkedList<>(); for (final String taskId : stageToSchedule.getTaskIds()) { // this happens when the belonging stage's other tasks have failed recoverable, // but this task's results are safe. 
- final TaskState.State taskState = - (TaskState.State) - jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState(); + final TaskState.State taskState = jobStateManager.getTaskState(taskId); switch (taskState) { case COMPLETE: @@ -372,7 +368,7 @@ private void scheduleStage(final Stage stageToSchedule) { break; case FAILED_RECOVERABLE: LOG.info("Re-scheduling {} for failure recovery", taskId); - jobStateManager.onTaskStateChanged(taskId, TaskState.State.READY); + jobStateManager.onTaskStateChanged(taskId, READY); taskIdsToSchedule.add(taskId); break; case ON_HOLD: @@ -382,13 +378,7 @@ private void scheduleStage(final Stage stageToSchedule) { throw new SchedulingException(new Throwable("Detected a FAILED_UNRECOVERABLE Task")); } } - if (stageState == StageState.State.FAILED_RECOVERABLE) { - // The 'failed_recoverable' stage has been selected as the next stage to execute. Change its state back to 'ready' - jobStateManager.onStageStateChanged(stageToSchedule.getId(), StageState.State.READY); - } - // attemptIdx is only initialized/updated when we set the stage's state to executing - jobStateManager.onStageStateChanged(stageToSchedule.getId(), StageState.State.EXECUTING); LOG.info("Scheduling Stage {}", stageToSchedule.getId()); // each readable and source task will be bounded in executor. 
@@ -397,7 +387,7 @@ private void scheduleStage(final Stage stageToSchedule) { taskIdsToSchedule.forEach(taskId -> { blockManagerMaster.onProducerTaskScheduled(taskId); final int taskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId); - final int attemptIdx = jobStateManager.getCurrentAttemptIndexForTask(taskId); + final int attemptIdx = jobStateManager.getTaskAttempt(taskId); LOG.debug("Enqueueing {}", taskId); pendingTaskCollection.add(new Task( @@ -464,9 +454,9 @@ private void onTaskExecutionComplete(final String executorId, } final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId); - if (jobStateManager.checkStageCompletion(stageIdForTaskUponCompletion)) { + if (jobStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE)) { // if the stage this task belongs to is complete, - if (!jobStateManager.checkJobTermination()) { // and if the job is not yet complete or failed, + if (!jobStateManager.isJobDone()) { scheduleNextStage(stageIdForTaskUponCompletion); } } @@ -490,7 +480,7 @@ private void onTaskExecutionOnHold(final String executorId, final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId); final boolean stageComplete = - jobStateManager.checkStageCompletion(stageIdForTaskUponCompletion); + jobStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE); if (stageComplete) { // get optimization vertex from the task. @@ -517,12 +507,10 @@ private void onTaskExecutionOnHold(final String executorId, * Action for after task execution has failed but it's recoverable. 
* @param executorId the ID of the executor * @param taskId the ID of the task - * @param newState the state this situation * @param failureCause the cause of failure */ private void onTaskExecutionFailedRecoverable(final String executorId, final String taskId, - final TaskState.State newState, final TaskState.RecoverableFailureCause failureCause) { LOG.info("{} failed in {} by {}", taskId, executorId, failureCause); executorRegistry.updateExecutor(executorId, (executor, state) -> { @@ -535,38 +523,12 @@ private void onTaskExecutionFailedRecoverable(final String executorId, switch (failureCause) { // Previous task must be re-executed, and incomplete tasks of the belonging stage must be rescheduled. case INPUT_READ_FAILURE: - jobStateManager.onTaskStateChanged(taskId, newState); - LOG.info("All tasks of {} will be made failed_recoverable.", stageId); - for (final Stage stage : physicalPlan.getStageDAG().getTopologicalSort()) { - if (stage.getId().equals(stageId)) { - LOG.info("Removing Tasks for {} before they are scheduled to an executor", stage.getId()); - pendingTaskCollection.removeTasksAndDescendants(stage.getId()); - stage.getTaskIds().forEach(dstTaskId -> { - if (jobStateManager.getTaskState(dstTaskId).getStateMachine().getCurrentState() - != TaskState.State.COMPLETE) { - jobStateManager.onTaskStateChanged(dstTaskId, TaskState.State.FAILED_RECOVERABLE); - blockManagerMaster.onProducerTaskFailed(dstTaskId); - } - }); - break; - } - } - // the stage this task belongs to has become failed recoverable. - // it is a good point to start searching for another stage to schedule. - scheduleNextStage(stageId); - break; - // The task executed successfully but there is something wrong with the output store. + // TODO #50: Carefully retry tasks in the scheduler case OUTPUT_WRITE_FAILURE: - jobStateManager.onTaskStateChanged(taskId, newState); - LOG.info("Only the failed task will be retried."); - - // the stage this task belongs to has become failed recoverable. 
- // it is a good point to start searching for another stage to schedule. blockManagerMaster.onProducerTaskFailed(taskId); scheduleNextStage(stageId); break; case CONTAINER_FAILURE: - jobStateManager.onTaskStateChanged(taskId, newState); LOG.info("Only the failed task will be retried."); break; default: diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java index c88e35a8..ebf79371 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java @@ -74,12 +74,12 @@ void scheduleJob(PhysicalPlan physicalPlan, * @param taskPutOnHold the ID of task that are put on hold. It is null otherwise. * @param failureCause for which the Task failed in the case of a recoverable failure. */ - void onTaskStateChanged(String executorId, - String taskId, - int attemptIdx, - TaskState.State newState, - @Nullable String taskPutOnHold, - TaskState.RecoverableFailureCause failureCause); + void onTaskStateReportFromExecutor(String executorId, + String taskId, + int attemptIdx, + TaskState.State newState, + @Nullable String taskPutOnHold, + TaskState.RecoverableFailureCause failureCause); /** * To be called when a job should be terminated. 
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java index a329c4b6..243ef075 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java @@ -17,7 +17,6 @@ import com.google.common.annotations.VisibleForTesting; import edu.snu.nemo.runtime.common.plan.Task; -import edu.snu.nemo.runtime.common.state.JobState; import edu.snu.nemo.runtime.common.state.TaskState; import edu.snu.nemo.runtime.master.JobStateManager; import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter; @@ -170,7 +169,7 @@ public void run() { doScheduleStage(); } jobStateManagers.values().forEach(jobStateManager -> { - if (jobStateManager.getJobState().getStateMachine().getCurrentState() == JobState.State.COMPLETE) { + if (jobStateManager.isJobDone()) { LOG.info("{} is complete.", jobStateManager.getJobId()); } else { LOG.info("{} is incomplete.", jobStateManager.getJobId()); diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java index db5b1d44..4aa2b0ac 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java @@ -15,8 +15,6 @@ */ package edu.snu.nemo.runtime.master; -import edu.snu.nemo.common.ir.edge.IREdge; -import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.conf.JobConf; import edu.snu.nemo.runtime.common.RuntimeIdGenerator; import edu.snu.nemo.runtime.common.message.MessageEnvironment; @@ -25,12 +23,9 @@ import edu.snu.nemo.runtime.common.plan.PhysicalPlan; import edu.snu.nemo.runtime.common.plan.PhysicalPlanGenerator; import 
edu.snu.nemo.runtime.common.plan.Stage; -import edu.snu.nemo.runtime.common.plan.StageEdge; import edu.snu.nemo.runtime.common.state.JobState; import edu.snu.nemo.runtime.common.state.StageState; import edu.snu.nemo.runtime.common.state.TaskState; -import edu.snu.nemo.common.dag.DAG; -import edu.snu.nemo.common.dag.DAGBuilder; import edu.snu.nemo.runtime.plangenerator.TestPlanGenerator; import org.apache.reef.tang.Injector; import org.apache.reef.tang.Tang; @@ -41,10 +36,9 @@ import org.powermock.modules.junit4.PowerMockRunner; import java.util.List; -import java.util.Map; import java.util.concurrent.*; +import java.util.stream.Collectors; -import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.mockito.Mockito.mock; @@ -56,14 +50,12 @@ @PrepareForTest(MetricMessageHandler.class) public final class JobStateManagerTest { private static final int MAX_SCHEDULE_ATTEMPT = 2; - private DAGBuilder<IRVertex, IREdge> irDAGBuilder; private BlockManagerMaster blockManagerMaster; private MetricMessageHandler metricMessageHandler; private PhysicalPlanGenerator physicalPlanGenerator; @Before public void setUp() throws Exception { - irDAGBuilder = new DAGBuilder<>(); final LocalMessageDispatcher messageDispatcher = new LocalMessageDispatcher(); final LocalMessageEnvironment messageEnvironment = new LocalMessageEnvironment(MessageEnvironment.MASTER_COMMUNICATION_ID, messageDispatcher); @@ -92,23 +84,18 @@ public void testPhysicalPlanStateChanges() throws Exception { for (int stageIdx = 0; stageIdx < stageList.size(); stageIdx++) { final Stage stage = stageList.get(stageIdx); - jobStateManager.onStageStateChanged(stage.getId(), StageState.State.EXECUTING); final List<String> taskIds = stage.getTaskIds(); taskIds.forEach(taskId -> { jobStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING); jobStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE); if 
(RuntimeIdGenerator.getIndexFromTaskId(taskId) == taskIds.size() - 1) { - assertTrue(jobStateManager.checkStageCompletion(stage.getId())); + assertEquals(StageState.State.COMPLETE, jobStateManager.getStageState(stage.getId())); } }); - final Map<String, TaskState> taskStateMap = jobStateManager.getIdToTaskStates(); - taskIds.forEach(taskId -> { - assertEquals(taskStateMap.get(taskId).getStateMachine().getCurrentState(), - TaskState.State.COMPLETE); - }); + taskIds.forEach(taskId -> assertEquals(jobStateManager.getTaskState(taskId), TaskState.State.COMPLETE)); if (stageIdx == stageList.size() - 1) { - assertEquals(jobStateManager.getJobState().getStateMachine().getCurrentState(), JobState.State.COMPLETE); + assertEquals(jobStateManager.getJobState(), JobState.State.COMPLETE); } } } @@ -116,26 +103,28 @@ public void testPhysicalPlanStateChanges() throws Exception { /** * Test whether the methods waiting finish of job works properly. */ - @Test(timeout = 1000) - public void testWaitUntilFinish() { - // Create a JobStateManager of an empty dag. - final DAG<IRVertex, IREdge> irDAG = irDAGBuilder.build(); - final DAG<Stage, StageEdge> physicalDAG = irDAG.convert(physicalPlanGenerator); - final JobStateManager jobStateManager = new JobStateManager( - new PhysicalPlan("TestPlan", physicalDAG), - blockManagerMaster, metricMessageHandler, MAX_SCHEDULE_ATTEMPT); + @Test(timeout = 2000) + public void testWaitUntilFinish() throws Exception { + final PhysicalPlan physicalPlan = + TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false); + final JobStateManager jobStateManager = + new JobStateManager(physicalPlan, blockManagerMaster, metricMessageHandler, MAX_SCHEDULE_ATTEMPT); - assertFalse(jobStateManager.checkJobTermination()); + assertFalse(jobStateManager.isJobDone()); // Wait for the job to finish and check the job state. // It have to return EXECUTING state after timeout. 
- JobState state = jobStateManager.waitUntilFinish(100, TimeUnit.MILLISECONDS); - assertEquals(state.getStateMachine().getCurrentState(), JobState.State.EXECUTING); + final JobState.State executingState = jobStateManager.waitUntilFinish(100, TimeUnit.MILLISECONDS); + assertEquals(JobState.State.EXECUTING, executingState); // Complete the job and check the result again. // It have to return COMPLETE. - jobStateManager.onJobStateChanged(JobState.State.COMPLETE); - state = jobStateManager.waitUntilFinish(); - assertEquals(state.getStateMachine().getCurrentState(), JobState.State.COMPLETE); + final List<String> tasks = physicalPlan.getStageDAG().getTopologicalSort().stream() + .flatMap(stage -> stage.getTaskIds().stream()) + .collect(Collectors.toList()); + tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING)); + tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE)); + final JobState.State completedState = jobStateManager.waitUntilFinish(); + assertEquals(JobState.State.COMPLETE, completedState); } } diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java index 2d1ff2ea..f0f5d191 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java @@ -158,8 +158,7 @@ private void scheduleAndCheckJobTermination(final PhysicalPlan plan) throws Inje LOG.debug("Checking that all stages of ScheduleGroup {} enter the executing state", scheduleGroupIdx); stages.forEach(stage -> { - while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() - != StageState.State.EXECUTING) { + while (jobStateManager.getStageState(stage.getId()) != StageState.State.EXECUTING) { } }); 
@@ -171,9 +170,9 @@ private void scheduleAndCheckJobTermination(final PhysicalPlan plan) throws Inje } LOG.debug("Waiting for job termination after sending stage completion events"); - while (!jobStateManager.checkJobTermination()) { + while (!jobStateManager.isJobDone()) { } - assertTrue(jobStateManager.checkJobTermination()); + assertTrue(jobStateManager.isJobDone()); } private List<Stage> filterStagesWithAScheduleGroupIndex( diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java index 96c65eb2..d2ec9629 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java @@ -48,8 +48,6 @@ import java.util.function.Function; import static edu.snu.nemo.runtime.common.state.StageState.State.COMPLETE; -import static edu.snu.nemo.runtime.common.state.StageState.State.EXECUTING; -import static junit.framework.TestCase.assertFalse; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -104,7 +102,7 @@ private Scheduler setUpScheduler(final boolean useMockSchedulerRunner) throws In /** * Tests fault tolerance after a container removal. */ - @Test(timeout=10000) + @Test(timeout=5000) public void testContainerRemoval() throws Exception { final ActiveContext activeContext = mock(ActiveContext.class); Mockito.doThrow(new RuntimeException()).when(activeContext).close(); @@ -139,7 +137,7 @@ public void testContainerRemoval() throws Exception { if (stage.getScheduleGroupIndex() == 0 || stage.getScheduleGroupIndex() == 1) { // There are 3 executors, each of capacity 2, and there are 6 Tasks in ScheduleGroup 0 and 1. 
- SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); assertTrue(pendingTaskCollection.isEmpty()); stage.getTaskIds().forEach(taskId -> @@ -148,18 +146,22 @@ public void testContainerRemoval() throws Exception { } else if (stage.getScheduleGroupIndex() == 2) { scheduler.onExecutorRemoved("a3"); // There are 2 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 2. - SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); // Due to round robin scheduling, "a2" is assured to have a running Task. scheduler.onExecutorRemoved("a2"); - while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != EXECUTING) { + // Re-schedule + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + executorRegistry, false); - } - assertEquals(jobStateManager.getAttemptCountForStage(stage.getId()), 2); + final Optional<Integer> maxTaskAttempt = stage.getTaskIds().stream() + .map(jobStateManager::getTaskAttempt).max(Integer::compareTo); + assertTrue(maxTaskAttempt.isPresent()); + assertEquals(2, (int) maxTaskAttempt.get()); - SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); assertTrue(pendingTaskCollection.isEmpty()); stage.getTaskIds().forEach(taskId -> @@ -168,7 +170,7 @@ public void testContainerRemoval() throws Exception { } else if (stage.getScheduleGroupIndex() == 3) { // There are 1 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 3. 
// Schedule only the first Task - SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, true); } else { throw new RuntimeException(String.format("Unexpected ScheduleGroupIndex: %d", @@ -180,7 +182,7 @@ public void testContainerRemoval() throws Exception { /** * Tests fault tolerance after an output write failure. */ - @Test(timeout=10000) + @Test(timeout=5000) public void testOutputFailure() throws Exception { final ActiveContext activeContext = mock(ActiveContext.class); Mockito.doThrow(new RuntimeException()).when(activeContext).close(); @@ -213,7 +215,7 @@ public void testOutputFailure() throws Exception { if (stage.getScheduleGroupIndex() == 0 || stage.getScheduleGroupIndex() == 1) { // There are 3 executors, each of capacity 2, and there are 6 Tasks in ScheduleGroup 0 and 1. - SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); assertTrue(pendingTaskCollection.isEmpty()); stage.getTaskIds().forEach(taskId -> @@ -221,7 +223,7 @@ public void testOutputFailure() throws Exception { taskId, TaskState.State.COMPLETE, 1)); } else if (stage.getScheduleGroupIndex() == 2) { // There are 3 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 2. 
- SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); assertTrue(pendingTaskCollection.isEmpty()); stage.getTaskIds().forEach(taskId -> @@ -229,16 +231,18 @@ public void testOutputFailure() throws Exception { taskId, TaskState.State.FAILED_RECOVERABLE, 1, TaskState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE)); - while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != EXECUTING) { + // Re-schedule + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + executorRegistry, false); - } + final Optional<Integer> maxTaskAttempt = stage.getTaskIds().stream() + .map(jobStateManager::getTaskAttempt).max(Integer::compareTo); + assertTrue(maxTaskAttempt.isPresent()); + assertEquals(2, (int) maxTaskAttempt.get()); - assertEquals(3, jobStateManager.getAttemptCountForStage(stage.getId())); - assertFalse(pendingTaskCollection.isEmpty()); - stage.getTaskIds().forEach(taskId -> { - assertEquals(jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState(), - TaskState.State.READY); - }); + assertTrue(pendingTaskCollection.isEmpty()); + stage.getTaskIds().forEach(taskId -> + assertEquals(TaskState.State.EXECUTING, jobStateManager.getTaskState(taskId))); } } } @@ -246,7 +250,7 @@ public void testOutputFailure() throws Exception { /** * Tests fault tolerance after an input read failure. 
*/ - @Test(timeout=10000) + @Test(timeout=5000) public void testInputReadFailure() throws Exception { final ActiveContext activeContext = mock(ActiveContext.class); Mockito.doThrow(new RuntimeException()).when(activeContext).close(); @@ -279,7 +283,7 @@ public void testInputReadFailure() throws Exception { if (stage.getScheduleGroupIndex() == 0 || stage.getScheduleGroupIndex() == 1) { // There are 3 executors, each of capacity 2, and there are 6 Tasks in ScheduleGroup 0 and 1. - SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); assertTrue(pendingTaskCollection.isEmpty()); stage.getTaskIds().forEach(taskId -> @@ -287,7 +291,7 @@ public void testInputReadFailure() throws Exception { taskId, TaskState.State.COMPLETE, 1)); } else if (stage.getScheduleGroupIndex() == 2) { // There are 3 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 2. 
- SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, executorRegistry, false); stage.getTaskIds().forEach(taskId -> @@ -295,15 +299,17 @@ public void testInputReadFailure() throws Exception { taskId, TaskState.State.FAILED_RECOVERABLE, 1, TaskState.RecoverableFailureCause.INPUT_READ_FAILURE)); - while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != EXECUTING) { + // Re-schedule + SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager, + executorRegistry, false); - } + final Optional<Integer> maxTaskAttempt = stage.getTaskIds().stream() + .map(jobStateManager::getTaskAttempt).max(Integer::compareTo); + assertTrue(maxTaskAttempt.isPresent()); + assertEquals(2, (int) maxTaskAttempt.get()); - assertEquals(2, jobStateManager.getAttemptCountForStage(stage.getId())); - stage.getTaskIds().forEach(taskId -> { - assertEquals(jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState(), - TaskState.State.READY); - }); + stage.getTaskIds().forEach(taskId -> + assertEquals(TaskState.State.EXECUTING, jobStateManager.getTaskState(taskId))); } } } @@ -331,12 +337,12 @@ public void testTaskReexecutionForFailure() throws Exception { final List<Stage> dagOf4Stages = plan.getStageDAG().getTopologicalSort(); int executorIdIndex = 1; - float removalChance = 0.7f; // Out of 1.0 + float removalChance = 0.5f; // Out of 1.0 final Random random = new Random(0); // Deterministic seed. 
for (final Stage stage : dagOf4Stages) { - while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != COMPLETE) { + while (jobStateManager.getStageState(stage.getId()) != COMPLETE) { // By chance, remove or add executor if (isTrueByChance(random, removalChance)) { // REMOVE EXECUTOR @@ -370,7 +376,7 @@ public void testTaskReexecutionForFailure() throws Exception { } } } - assertTrue(jobStateManager.checkJobTermination()); + assertTrue(jobStateManager.isJobDone()); } private boolean isTrueByChance(final Random random, final float chance) { diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java index 55fbaef8..cb15f7ad 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java @@ -41,20 +41,20 @@ static void completeStage(final JobStateManager jobStateManager, final int attemptIdx) { // Loop until the stage completes. while (true) { - final Enum stageState = jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState(); + final StageState.State stageState = jobStateManager.getStageState(stage.getId()); if (StageState.State.COMPLETE == stageState) { // Stage has completed, so we break out of the loop. 
break; } else if (StageState.State.EXECUTING == stageState) { stage.getTaskIds().forEach(taskId -> { - final Enum tgState = jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState(); - if (TaskState.State.EXECUTING == tgState) { + final TaskState.State taskState = jobStateManager.getTaskState(taskId); + if (TaskState.State.EXECUTING == taskState) { sendTaskStateEventToScheduler(scheduler, executorRegistry, taskId, TaskState.State.COMPLETE, attemptIdx, null); - } else if (TaskState.State.READY == tgState || TaskState.State.COMPLETE == tgState) { + } else if (TaskState.State.READY == taskState || TaskState.State.COMPLETE == taskState) { // Skip READY (try in the next loop and see if it becomes EXECUTING) and COMPLETE. } else { - throw new IllegalStateException(tgState.toString()); + throw new IllegalStateException(taskState.toString()); } }); } else if (StageState.State.READY == stageState) { @@ -88,7 +88,7 @@ static void sendTaskStateEventToScheduler(final Scheduler scheduler, break; } } - scheduler.onTaskStateChanged(scheduledExecutor.getExecutorId(), taskId, attemptIdx, + scheduler.onTaskStateReportFromExecutor(scheduledExecutor.getExecutorId(), taskId, attemptIdx, newState, null, cause); } @@ -100,17 +100,17 @@ static void sendTaskStateEventToScheduler(final Scheduler scheduler, sendTaskStateEventToScheduler(scheduler, executorRegistry, taskId, newState, attemptIdx, null); } - static void mockSchedulerRunner(final PendingTaskCollection pendingTaskCollection, - final SchedulingPolicy schedulingPolicy, - final JobStateManager jobStateManager, - final ExecutorRegistry executorRegistry, - final boolean isPartialSchedule) { + static void mockSchedulingBySchedulerRunner(final PendingTaskCollection pendingTaskCollection, + final SchedulingPolicy schedulingPolicy, + final JobStateManager jobStateManager, + final ExecutorRegistry executorRegistry, + final boolean scheduleOnlyTheFirstStage) { final SchedulerRunner schedulerRunner = new 
SchedulerRunner(schedulingPolicy, pendingTaskCollection, executorRegistry); schedulerRunner.scheduleJob(jobStateManager); while (!pendingTaskCollection.isEmpty()) { schedulerRunner.doScheduleStage(); - if (isPartialSchedule) { + if (scheduleOnlyTheFirstStage) { // Schedule only the first stage break; } diff --git a/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java b/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java index 0d60a5d4..c962ad33 100644 --- a/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java +++ b/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java @@ -22,6 +22,7 @@ import edu.snu.nemo.common.dag.DAGBuilder; import edu.snu.nemo.common.ir.edge.IREdge; import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.common.test.EmptyComponents; import edu.snu.nemo.conf.JobConf; import edu.snu.nemo.runtime.common.message.MessageEnvironment; import edu.snu.nemo.runtime.common.message.local.LocalMessageDispatcher; @@ -31,9 +32,11 @@ import edu.snu.nemo.runtime.common.plan.Stage; import edu.snu.nemo.runtime.common.plan.StageEdge; import edu.snu.nemo.runtime.common.state.JobState; +import edu.snu.nemo.runtime.common.state.TaskState; import edu.snu.nemo.runtime.master.MetricMessageHandler; import edu.snu.nemo.runtime.master.BlockManagerMaster; import edu.snu.nemo.runtime.master.JobStateManager; +import edu.snu.nemo.runtime.plangenerator.TestPlanGenerator; import org.apache.reef.tang.Injector; import org.apache.reef.tang.Tang; import org.junit.Test; @@ -41,7 +44,9 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; +import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; @@ -68,22 +73,17 @@ public void testState() throws Exception { // Wait for connection but not connected. 
assertEquals(clientEndpoint.waitUntilJobFinish(100, TimeUnit.MILLISECONDS), JobState.State.READY); - // Create a JobStateManager of an empty dag and create a DriverEndpoint with it. - final DAGBuilder<IRVertex, IREdge> irDagBuilder = new DAGBuilder<>(); - final DAG<IRVertex, IREdge> irDAG = irDagBuilder.build(); - final Injector injector = Tang.Factory.getTang().newInjector(); - injector.bindVolatileParameter(JobConf.DAGDirectory.class, ""); - final PhysicalPlanGenerator physicalPlanGenerator = injector.getInstance(PhysicalPlanGenerator.class); - final DAG<Stage, StageEdge> physicalDAG = irDAG.convert(physicalPlanGenerator); - + // Create a JobStateManager of a dag and create a DriverEndpoint with it. + final PhysicalPlan physicalPlan = + TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false); final LocalMessageDispatcher messageDispatcher = new LocalMessageDispatcher(); final LocalMessageEnvironment messageEnvironment = new LocalMessageEnvironment(MessageEnvironment.MASTER_COMMUNICATION_ID, messageDispatcher); + final Injector injector = Tang.Factory.getTang().newInjector(); injector.bindVolatileInstance(MessageEnvironment.class, messageEnvironment); final BlockManagerMaster pmm = injector.getInstance(BlockManagerMaster.class); - final JobStateManager jobStateManager = new JobStateManager( - new PhysicalPlan("TestPlan", physicalDAG), - pmm, metricMessageHandler, MAX_SCHEDULE_ATTEMPT); + final JobStateManager jobStateManager = + new JobStateManager(physicalPlan, pmm, metricMessageHandler, MAX_SCHEDULE_ATTEMPT); final DriverEndpoint driverEndpoint = new DriverEndpoint(jobStateManager, clientEndpoint); @@ -94,8 +94,12 @@ public void testState() throws Exception { assertEquals(clientEndpoint.waitUntilJobFinish(100, TimeUnit.MILLISECONDS), JobState.State.EXECUTING); // Check finish. 
- jobStateManager.onJobStateChanged(JobState.State.COMPLETE); - assertEquals(clientEndpoint.waitUntilJobFinish(), JobState.State.COMPLETE); + final List<String> tasks = physicalPlan.getStageDAG().getTopologicalSort().stream() + .flatMap(stage -> stage.getTaskIds().stream()) + .collect(Collectors.toList()); + tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING)); + tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE)); + assertEquals(JobState.State.COMPLETE, clientEndpoint.waitUntilJobFinish()); } /** ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services