This is an automated email from the ASF dual-hosted git repository.

jeagles pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/hadoop.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 6ad2f66  MAPREDUCE-7252. Handling 0 progress in SimpleExponential task runtime estimator
6ad2f66 is described below

commit 6ad2f662eb4995637c537fd1f23de9e246c023a1
Author: Ahmed Hussein <ahuss...@apache.org>
AuthorDate: Wed Jan 8 11:08:13 2020 -0600

    MAPREDUCE-7252. Handling 0 progress in SimpleExponential task runtime estimator
    
    Signed-off-by: Jonathan Eagles <jeag...@gmail.com>
    (cherry picked from commit cdd6efd3ab6917e30b4c5c7b261f61838901bb37)
---
 .../mapreduce/v2/app/speculate/DataStatistics.java |  28 +++--
 .../SimpleExponentialTaskRuntimeEstimator.java     |  67 +++++++----
 .../forecast/SimpleExponentialSmoothing.java       | 131 +++++++++++++--------
 .../v2/app/speculate/forecast/package-info.java    |  20 ++++
 .../org/apache/hadoop/mapreduce/v2/app/MRApp.java  |  42 ++++++-
 .../v2/TestSpeculativeExecutionWithMRApp.java      | 116 ++++++++++++++++--
 6 files changed, 308 insertions(+), 96 deletions(-)

diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java
index 9f1c122..036eb45 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java
@@ -18,6 +18,11 @@
 package org.apache.hadoop.mapreduce.v2.app.speculate;
 
 public class DataStatistics {
+
+  /**
+   * factor used to calculate confidence interval within 95%.
+   */
+  private static final double DEFAULT_CI_FACTOR = 1.96;
   private int count = 0;
   private double sum = 0;
   private double sumSquares = 0;
@@ -25,25 +30,26 @@ public class DataStatistics {
   public DataStatistics() {
   }
 
-  public DataStatistics(double initNum) {
+  public DataStatistics(final double initNum) {
     this.count = 1;
     this.sum = initNum;
     this.sumSquares = initNum * initNum;
   }
 
-  public synchronized void add(double newNum) {
+  public synchronized void add(final double newNum) {
     this.count++;
     this.sum += newNum;
     this.sumSquares += newNum * newNum;
   }
 
-  public synchronized void updateStatistics(double old, double update) {
-       this.sum += update - old;
-       this.sumSquares += (update * update) - (old * old);
+  public synchronized void updateStatistics(final double old,
+      final double update) {
+    this.sum += update - old;
+    this.sumSquares += (update * update) - (old * old);
   }
 
   public synchronized double mean() {
-    return count == 0 ? 0.0 : sum/count;
+    return count == 0 ? 0.0 : sum / count;
   }
 
   public synchronized double var() {
@@ -52,14 +58,14 @@ public class DataStatistics {
       return 0.0;
     }
     double mean = mean();
-    return Math.max((sumSquares/count) - mean * mean, 0.0d);
+    return Math.max((sumSquares / count) - mean * mean, 0.0d);
   }
 
   public synchronized double std() {
     return Math.sqrt(this.var());
   }
 
-  public synchronized double outlier(float sigma) {
+  public synchronized double outlier(final float sigma) {
     if (count != 0.0) {
       return mean() + std() * sigma;
     }
@@ -78,10 +84,12 @@ public class DataStatistics {
    * @return the mean value adding 95% confidence interval
    */
   public synchronized double meanCI() {
-    if (count <= 1) return 0.0;
+    if (count <= 1) {
+      return 0.0;
+    }
     double currMean = mean();
     double currStd = std();
-    return currMean + (1.96 * currStd / Math.sqrt(count));
+    return currMean + (DEFAULT_CI_FACTOR * currStd / Math.sqrt(count));
   }
 
   public String toString() {
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java
index f244b20..2838916 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java
@@ -33,7 +33,22 @@ import org.apache.hadoop.mapreduce.v2.app.speculate.forecast.SimpleExponentialSm
  * A task Runtime Estimator based on exponential smoothing.
  */
 public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
-  private final static long DEFAULT_ESTIMATE_RUNTIME = -1L;
+
+  /**
+   * The default value returned by the estimator when no records exist.
+   */
+  private static final long DEFAULT_ESTIMATE_RUNTIME = -1L;
+
+  /**
+   * Given a forecast of value 0.0, it is getting replaced by the default value
+   * to avoid division by 0.
+   */
+  private static final double DEFAULT_PROGRESS_VALUE = 1E-10;
+
+  /**
+   * Factor used to calculate the confidence interval.
+   */
+  private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25;
 
   /**
    * Constant time used to calculate the smoothing exponential factor.
@@ -53,11 +68,15 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
    */
   private long stagnatedWindow;
 
+  /**
+   * A map of TA Id to the statistic model of smooth exponential.
+   */
   private final ConcurrentMap<TaskAttemptId,
       AtomicReference<SimpleExponentialSmoothing>>
       estimates = new ConcurrentHashMap<>();
 
-  private SimpleExponentialSmoothing getForecastEntry(TaskAttemptId attemptID) {
+  private SimpleExponentialSmoothing getForecastEntry(
+      final TaskAttemptId attemptID) {
     AtomicReference<SimpleExponentialSmoothing> entryRef = estimates
         .get(attemptID);
     if (entryRef == null) {
@@ -66,13 +85,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
     return entryRef.get();
   }
 
-  private void incorporateReading(TaskAttemptId attemptID,
-      float newRawData, long newTimeStamp) {
+  private void incorporateReading(final TaskAttemptId attemptID,
+      final float newRawData, final long newTimeStamp) {
     SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID);
     if (foreCastEntry == null) {
       Long tStartTime = startTimes.get(attemptID);
       // skip if the startTime is not set yet
-      if(tStartTime == null) {
+      if (tStartTime == null) {
         return;
       }
       estimates.putIfAbsent(attemptID,
@@ -86,7 +105,8 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
   }
 
   @Override
-  public void contextualize(Configuration conf, AppContext context) {
+  public void contextualize(final Configuration conf,
+      final AppContext context) {
     super.contextualize(conf, context);
 
     constTime
@@ -103,18 +123,16 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
   }
 
   @Override
-  public long estimatedRuntime(TaskAttemptId id) {
+  public long estimatedRuntime(final TaskAttemptId id) {
     SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
     if (foreCastEntry == null) {
       return DEFAULT_ESTIMATE_RUNTIME;
     }
-    // TODO: What should we do when estimate is zero
-    double remainingWork = Math.min(1.0, 1.0 - foreCastEntry.getRawData());
-    double forecast = foreCastEntry.getForecast();
-    if (forecast <= 0.0) {
-      return DEFAULT_ESTIMATE_RUNTIME;
-    }
-    long remainingTime = (long)(remainingWork / forecast);
+    double remainingWork = Math
+        .max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData()));
+    double forecast = Math
+        .max(DEFAULT_PROGRESS_VALUE, foreCastEntry.getForecast());
+    long remainingTime = (long) (remainingWork / forecast);
     long estimatedRuntime = remainingTime
         + foreCastEntry.getTimeStamp()
         - foreCastEntry.getStartTime();
@@ -122,30 +140,32 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
   }
 
   @Override
-  public long estimatedNewAttemptRuntime(TaskId id) {
+  public long estimatedNewAttemptRuntime(final TaskId id) {
     DataStatistics statistics = dataStatisticsForTask(id);
 
     if (statistics == null) {
-      return -1L;
+      return DEFAULT_ESTIMATE_RUNTIME;
     }
 
     double statsMeanCI = statistics.meanCI();
     double expectedVal =
-        statsMeanCI + Math.min(statsMeanCI * 0.25, statistics.std() / 2);
-    return (long)(expectedVal);
+        statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR,
+            statistics.std() / 2);
+    return (long) (expectedVal);
   }
 
   @Override
-  public boolean hasStagnatedProgress(TaskAttemptId id, long timeStamp) {
+  public boolean hasStagnatedProgress(final TaskAttemptId id,
+      final long timeStamp) {
     SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
-    if(foreCastEntry == null) {
+    if (foreCastEntry == null) {
       return false;
     }
     return foreCastEntry.isDataStagnated(timeStamp);
   }
 
   @Override
-  public long runtimeEstimateVariance(TaskAttemptId id) {
+  public long runtimeEstimateVariance(final TaskAttemptId id) {
     SimpleExponentialSmoothing forecastEntry = getForecastEntry(id);
     if (forecastEntry == null) {
       return DEFAULT_ESTIMATE_RUNTIME;
@@ -154,12 +174,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
     if (forecastEntry.isDefaultForecast(forecast)) {
       return DEFAULT_ESTIMATE_RUNTIME;
     }
-    //TODO: What is the best way to measure variance in runtime
+    //TODO What is the best way to measure variance in runtime
     return 0L;
   }
 
   @Override
-  public void updateAttempt(TaskAttemptStatus status, long timestamp) {
+  public void updateAttempt(final TaskAttemptStatus status,
+      final long timestamp) {
     super.updateAttempt(status, timestamp);
     TaskAttemptId attemptID = status.id;
 
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java
index e1ef7be..0e00068 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java
@@ -24,108 +24,145 @@ import java.util.concurrent.atomic.AtomicReference;
  * Implementation of the static model for Simple exponential smoothing.
  */
 public class SimpleExponentialSmoothing {
-  public final static double DEFAULT_FORECAST = -1.0;
+  private static final double DEFAULT_FORECAST = -1.0;
   private final int kMinimumReads;
   private final long kStagnatedWindow;
   private final long startTime;
   private long timeConstant;
 
+  /**
+   * Holds reference to the current forecast record.
+   */
   private AtomicReference<ForecastRecord> forecastRefEntry;
 
-  public static SimpleExponentialSmoothing createForecast(long timeConstant,
-      int skipCnt, long stagnatedWindow, long timeStamp) {
+  public static SimpleExponentialSmoothing createForecast(
+      final long timeConstant,
+      final int skipCnt, final long stagnatedWindow, final long timeStamp) {
     return new SimpleExponentialSmoothing(timeConstant, skipCnt,
         stagnatedWindow, timeStamp);
   }
 
-  SimpleExponentialSmoothing(long ktConstant, int skipCnt,
-      long stagnatedWindow, long timeStamp) {
-    kMinimumReads = skipCnt;
-    kStagnatedWindow = stagnatedWindow;
+  SimpleExponentialSmoothing(final long ktConstant, final int skipCnt,
+      final long stagnatedWindow, final long timeStamp) {
+    this.kMinimumReads = skipCnt;
+    this.kStagnatedWindow = stagnatedWindow;
     this.timeConstant = ktConstant;
     this.startTime = timeStamp;
     this.forecastRefEntry = new AtomicReference<ForecastRecord>(null);
   }
 
   private class ForecastRecord {
-    private double alpha;
-    private long timeStamp;
-    private double sample;
-    private double rawData;
+    private final double alpha;
+    private final long timeStamp;
+    private final double sample;
+    private final double rawData;
     private double forecast;
-    private double sseError;
-    private long myIndex;
+    private final double sseError;
+    private final long myIndex;
+    private ForecastRecord prevRec;
 
-    ForecastRecord(double forecast, double rawData, long timeStamp) {
-      this(0.0, forecast, rawData, forecast, timeStamp, 0.0, 0);
+    ForecastRecord(final double currForecast, final double currRawData,
+        final long currTimeStamp) {
+      this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0);
     }
 
-    ForecastRecord(double alpha, double sample, double rawData,
-        double forecast, long timeStamp, double accError, long index) {
-      this.timeStamp = timeStamp;
-      this.alpha = alpha;
-      this.sseError = 0.0;
-      this.sample = sample;
-      this.forecast = forecast;
-      this.rawData = rawData;
+    ForecastRecord(final double alphaVal, final double currSample,
+        final double currRawData,
+        final double currForecast, final long currTimeStamp,
+        final double accError,
+        final long index) {
+      this.timeStamp = currTimeStamp;
+      this.alpha = alphaVal;
+      this.sample = currSample;
+      this.forecast = currForecast;
+      this.rawData = currRawData;
       this.sseError = accError;
       this.myIndex = index;
     }
 
-    private double preProcessRawData(double rData, long newTime) {
+    private ForecastRecord createForecastRecord(final double alphaVal,
+        final double currSample,
+        final double currRawData,
+        final double currForecast, final long currTimeStamp,
+        final double accError,
+        final long index,
+        final ForecastRecord prev) {
+      ForecastRecord forecastRec =
+          new ForecastRecord(alphaVal, currSample, currRawData, currForecast,
+              currTimeStamp, accError, index);
+      forecastRec.prevRec = prev;
+      return forecastRec;
+    }
+
+    private double preProcessRawData(final double rData, final long newTime) {
       return processRawData(this.rawData, this.timeStamp, rData, newTime);
     }
 
-    public ForecastRecord append(long newTimeStamp, double rData) {
-      if (this.timeStamp > newTimeStamp) {
+    public ForecastRecord append(final long newTimeStamp, final double rData) {
+      if (this.timeStamp >= newTimeStamp
+          && Double.compare(this.rawData, rData) >= 0) {
+        // progress reported twice. Do nothing.
         return this;
       }
-      double newSample = preProcessRawData(rData, newTimeStamp);
+      ForecastRecord refRecord = this;
+      if (newTimeStamp == this.timeStamp) {
+        // we need to restore old value if possible
+        if (this.prevRec != null) {
+          refRecord = this.prevRec;
+        }
+      }
+      double newSample = refRecord.preProcessRawData(rData, newTimeStamp);
       long deltaTime = this.timeStamp - newTimeStamp;
-      if (this.myIndex == kMinimumReads) {
+      if (refRecord.myIndex == kMinimumReads) {
         timeConstant = Math.max(timeConstant, newTimeStamp - startTime);
       }
       double smoothFactor =
           1 - Math.exp(((double) deltaTime) / timeConstant);
       double forecastVal =
-          smoothFactor * newSample + (1.0 - smoothFactor) * this.forecast;
+          smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast;
       double newSSEError =
-          this.sseError + Math.pow(newSample - this.forecast, 2);
-      return new ForecastRecord(smoothFactor, newSample, rData, forecastVal,
-          newTimeStamp, newSSEError, this.myIndex + 1);
+          refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2);
+      return refRecord
+          .createForecastRecord(smoothFactor, newSample, rData, forecastVal,
+              newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord);
     }
-
   }
 
-  public boolean isDataStagnated(long timeStamp) {
+  /**
+   * checks if the task is hanging up.
+   * @param timeStamp current time of the scan.
+   * @return true if we have number of samples > kMinimumReads and the record
+   * timestamp has expired.
+   */
+  public boolean isDataStagnated(final long timeStamp) {
     ForecastRecord rec = forecastRefEntry.get();
-    if (rec != null && rec.myIndex <= kMinimumReads) {
-      return (rec.timeStamp + kStagnatedWindow) < timeStamp;
+    if (rec != null && rec.myIndex > kMinimumReads) {
+      return (rec.timeStamp + kStagnatedWindow) > timeStamp;
     }
     return false;
   }
 
-  static double processRawData(double oldRawData, long oldTime,
-      double newRawData, long newTime) {
+  static double processRawData(final double oldRawData, final long oldTime,
+      final double newRawData, final long newTime) {
     double rate = (newRawData - oldRawData) / (newTime - oldTime);
     return rate;
   }
 
-  public void incorporateReading(long timeStamp, double rawData) {
+  public void incorporateReading(final long timeStamp,
+      final double currRawData) {
     ForecastRecord oldRec = forecastRefEntry.get();
     if (oldRec == null) {
       double oldForecast =
-          processRawData(0, startTime, rawData, timeStamp);
+          processRawData(0, startTime, currRawData, timeStamp);
       forecastRefEntry.compareAndSet(null,
           new ForecastRecord(oldForecast, 0.0, startTime));
-      incorporateReading(timeStamp, rawData);
+      incorporateReading(timeStamp, currRawData);
       return;
     }
     while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp,
-        rawData))) {
+        currRawData))) {
       oldRec = forecastRefEntry.get();
     }
-
   }
 
   public double getForecast() {
@@ -136,7 +173,7 @@ public class SimpleExponentialSmoothing {
     return DEFAULT_FORECAST;
   }
 
-  public boolean isDefaultForecast(double value) {
+  public boolean isDefaultForecast(final double value) {
     return value == DEFAULT_FORECAST;
   }
 
@@ -148,7 +185,7 @@ public class SimpleExponentialSmoothing {
     return DEFAULT_FORECAST;
   }
 
-  public boolean isErrorWithinBound(double bound) {
+  public boolean isErrorWithinBound(final double bound) {
     double squaredErr = getSSE();
     if (squaredErr < 0) {
       return false;
@@ -185,8 +222,8 @@ public class SimpleExponentialSmoothing {
     String res = "NULL";
     ForecastRecord rec = forecastRefEntry.get();
     if (rec != null) {
-      res =  "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp +
-          ", forecast: " + rec.forecast
+      res =  "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp
+          + ", forecast: " + rec.forecast
           + ", sample: " + rec.sample + ", raw: " + rec.rawData + ", error: "
           + rec.sseError + ", alpha: " + rec.alpha;
     }
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java
new file mode 100644
index 0000000..52b8955
--- /dev/null
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+@InterfaceAudience.Private
+package org.apache.hadoop.mapreduce.v2.app.speculate.forecast;
+import org.apache.hadoop.classification.InterfaceAudience;
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java
index a6e57ca..70ea18a 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java
@@ -22,8 +22,11 @@ import java.io.File;
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.Arrays;
 import java.util.EnumSet;
 
+import java.util.List;
+import java.util.stream.Collectors;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileContext;
 import org.apache.hadoop.fs.FileSystem;
@@ -372,18 +375,45 @@ public class MRApp extends MRAppMaster {
     TaskAttemptReport report = attempt.getReport();
     while (!finalState.equals(report.getTaskAttemptState()) &&
         timeoutSecs++ < 20) {
-      System.out.println("TaskAttempt State is : " + report.getTaskAttemptState() +
-          " Waiting for state : " + finalState +
-          "   progress : " + report.getProgress());
+      System.out.println(
+          "TaskAttempt " + attempt.getID().toString() + "  State is : "
+              + report.getTaskAttemptState()
+              + " Waiting for state : " + finalState
+              + "   progress : " + report.getProgress());
       report = attempt.getReport();
       Thread.sleep(500);
     }
-    System.out.println("TaskAttempt State is : " + report.getTaskAttemptState());
+    System.out.println("TaskAttempt State is : "
+        + report.getTaskAttemptState());
     Assert.assertEquals("TaskAttempt state is not correct (timedout)",
-        finalState, 
+        finalState,
         report.getTaskAttemptState());
   }
 
+  public void waitForState(TaskAttempt attempt,
+      TaskAttemptState...finalStates) throws Exception {
+    int timeoutSecs = 0;
+    TaskAttemptReport report = attempt.getReport();
+    List<TaskAttemptState> targetStates =  Arrays.asList(finalStates);
+    String statesValues = targetStates.stream().map(Object::toString).collect(
+        Collectors.joining(","));
+    while (!targetStates.contains(report.getTaskAttemptState()) &&
+        timeoutSecs++ < 20) {
+      System.out.println(
+          "TaskAttempt " + attempt.getID().toString() + "  State is : "
+              + report.getTaskAttemptState()
+              + " Waiting for states: " + statesValues
+              + ". curent state is : " + report.getTaskAttemptState()
+              + ".   progress : " + report.getProgress());
+      report = attempt.getReport();
+      Thread.sleep(500);
+    }
+    System.out.println("TaskAttempt State is : "
+        + report.getTaskAttemptState());
+    Assert.assertTrue("TaskAttempt state is not correct (timedout)",
+        targetStates.contains(report.getTaskAttemptState()));
+  }
+
   public void waitForState(Task task, TaskState finalState) throws Exception {
     int timeoutSecs = 0;
     TaskReport report = task.getReport();
@@ -396,7 +426,7 @@ public class MRApp extends MRAppMaster {
       Thread.sleep(500);
     }
     System.out.println("Task State is : " + report.getTaskState());
-    Assert.assertEquals("Task state is not correct (timedout)", finalState, 
+    Assert.assertEquals("Task state is not correct (timedout)", finalState,
         report.getTaskState());
   }
 
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java
index 940f142..d4d432b 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java
@@ -18,11 +18,14 @@
 
 package org.apache.hadoop.mapreduce.v2;
 
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.hadoop.mapreduce.MRJobConfig;
@@ -50,19 +53,94 @@ import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.hadoop.yarn.util.Clock;
 import org.apache.hadoop.yarn.util.ControlledClock;
 import org.apache.hadoop.yarn.util.SystemClock;
+import org.junit.Rule;
 import org.junit.Test;
 
 import com.google.common.base.Supplier;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.model.Statement;
 
+/**
+ * The type Test speculative execution with mr app.
+ * It test the speculation behavior given a list of estimator classes.
+ */
 @SuppressWarnings({ "unchecked", "rawtypes" })
 @RunWith(Parameterized.class)
 public class TestSpeculativeExecutionWithMRApp {
-
+  /** Number of times to re-try the failing tests. */
+  private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3;
   private static final int NUM_MAPPERS = 5;
   private static final int NUM_REDUCERS = 0;
 
+  /**
+   * Speculation has non-deterministic behavior due to racing and timing. Use
+   * retry to verify that junit tests can pass.
+   */
+  @Retention(RetentionPolicy.RUNTIME)
+  public @interface Retry {}
+
+  /**
+   * The type Retry rule.
+   */
+  class RetryRule implements TestRule {
+
+    private AtomicInteger retryCount;
+
+    /**
+     * Instantiates a new Retry rule.
+     *
+     * @param retries the retries
+     */
+    RetryRule(int retries) {
+      super();
+      this.retryCount = new AtomicInteger(retries);
+    }
+
+    @Override
+    public Statement apply(final Statement base,
+        final Description description) {
+      return new Statement() {
+        @Override
+        public void evaluate() throws Throwable {
+          Throwable caughtThrowable = null;
+
+          while (retryCount.getAndDecrement() > 0) {
+            try {
+              base.evaluate();
+              return;
+            } catch (Throwable t) {
+              if (retryCount.get() > 0 &&
+                  description.getAnnotation(Retry.class) != null) {
+                caughtThrowable = t;
+                System.out.println(
+                    description.getDisplayName() +
+                        ": Failed, " +
+                        retryCount.toString() +
+                        " retries remain");
+              } else {
+                throw caughtThrowable;
+              }
+            }
+          }
+        }
+      };
+    }
+  }
+
+  /**
+   * The Rule.
+   */
+  @Rule
+  public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES);
+
+  /**
+   * Gets test parameters.
+   *
+   * @return the test parameters
+   */
   @Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})")
   public static Collection<Object[]> getTestParameters() {
     return Arrays.asList(new Object[][] {
@@ -73,12 +151,23 @@ public class TestSpeculativeExecutionWithMRApp {
 
   private Class<? extends TaskRuntimeEstimator> estimatorClass;
 
+  /**
+   * Instantiates a new Test speculative execution with mr app.
+   *
+   * @param estimatorKlass the estimator klass
+   */
   public TestSpeculativeExecutionWithMRApp(
       Class<? extends TaskRuntimeEstimator>  estimatorKlass) {
     this.estimatorClass = estimatorKlass;
   }
 
-  @Test
+  /**
+   * Test speculate successful without update events.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
+  @Test (timeout = 360000)
   public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception {
 
     Clock actualClock = SystemClock.getInstance();
@@ -128,7 +217,8 @@ public class TestSpeculativeExecutionWithMRApp {
             TaskAttemptEventType.TA_DONE));
           appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
             TaskAttemptEventType.TA_CONTAINER_COMPLETED));
-          app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED);
+          app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED,
+              TaskAttemptState.KILLED);
         }
       }
     }
@@ -150,8 +240,14 @@ public class TestSpeculativeExecutionWithMRApp {
     app.waitForState(Service.STATE.STOPPED);
   }
 
-  @Test
-  public void testSepculateSuccessfulWithUpdateEvents() throws Exception {
+  /**
+   * Test speculate successful with update events.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
+  @Test (timeout = 360000)
+  public void testSpeculateSuccessfulWithUpdateEvents() throws Exception {
 
     Clock actualClock = SystemClock.getInstance();
     final ControlledClock clock = new ControlledClock(actualClock);
@@ -198,7 +294,8 @@ public class TestSpeculativeExecutionWithMRApp {
           appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
             TaskAttemptEventType.TA_CONTAINER_COMPLETED));
           numTasksToFinish--;
-          app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED);
+          app.waitForState(taskAttempt.getValue(), TaskAttemptState.KILLED,
+              TaskAttemptState.SUCCEEDED);
         } else {
           // The last task is chosen for speculation
           TaskAttemptStatus status =
@@ -214,13 +311,12 @@ public class TestSpeculativeExecutionWithMRApp {
     }
 
     clock.setTime(System.currentTimeMillis() + 15000);
-    // give a chance to the speculator thread to run a scan before we proceed
-    // with updating events
-    Thread.yield();
+
     for (Map.Entry<TaskId, Task> task : tasks.entrySet()) {
       for (Map.Entry<TaskAttemptId, TaskAttempt> taskAttempt : task.getValue()
         .getAttempts().entrySet()) {
-        if (taskAttempt.getValue().getState() != TaskAttemptState.SUCCEEDED) {
+        if (!(taskAttempt.getValue().getState() == TaskAttemptState.SUCCEEDED
+            || taskAttempt.getValue().getState() == TaskAttemptState.KILLED)) {
           TaskAttemptStatus status =
               createTaskAttemptStatus(taskAttempt.getKey(), (float) 0.75,
                 TaskAttemptState.RUNNING);


---------------------------------------------------------------------
To unsubscribe, e-mail: common-commits-unsubscr...@hadoop.apache.org
For additional commands, e-mail: common-commits-h...@hadoop.apache.org

Reply via email to