Repository: spark
Updated Branches:
  refs/heads/branch-2.1 77d11df4f -> 46e6b6c0e


[SPARK-22897][CORE] Expose stageAttemptId in TaskContext

Adds stageAttemptId to TaskContext and updates the corresponding constructor calls.

Added a new test in TaskContextSuite that covers two cases:
1. Normal case without failure
2. Exception case with resubmitted stages

Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897)
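
For context, a minimal, illustrative sketch (not part of this patch) of how user code can read the new getter from inside a task; the app name and partition counts below are arbitrary:

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object StageAttemptNumberExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("stage-attempt-demo")
        val sc = new SparkContext(conf)
        try {
          // Each task reports the attempt number of the stage it belongs to.
          // On a clean first run every value should be 0.
          val attempts = sc.parallelize(1 to 4, 2).mapPartitions { _ =>
            Iterator(TaskContext.get().stageAttemptNumber())
          }.collect()
          println(s"stage attempt numbers seen by tasks: ${attempts.mkString(", ")}")
        } finally {
          sc.stop()
        }
      }
    }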

Author: Xianjin YE <advance...@gmail.com>

Closes #20082 from advancedxy/SPARK-22897.

(cherry picked from commit a6fc300e91273230e7134ac6db95ccb4436c6f8f)
Signed-off-by: Marcelo Vanzin <van...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e5ccac21
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e5ccac21
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e5ccac21

Branch: refs/heads/branch-2.1
Commit: e5ccac21db69a5698e70d8fb993296fa854de132
Parents: 77d11df
Author: Xianjin YE <advance...@gmail.com>
Authored: Tue Jan 2 23:30:38 2018 +0800
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Thu Jun 21 13:47:27 2018 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskContext.scala    |  9 +++++-
 .../org/apache/spark/TaskContextImpl.scala      |  5 ++--
 .../scala/org/apache/spark/scheduler/Task.scala |  1 +
 .../spark/JavaTaskContextCompileCheck.java      |  2 ++
 .../scala/org/apache/spark/ShuffleSuite.scala   |  6 ++--
 .../spark/memory/MemoryTestingUtils.scala       |  1 +
 .../spark/scheduler/TaskContextSuite.scala      | 29 ++++++++++++++++++--
 .../spark/storage/BlockInfoManagerSuite.scala   |  2 +-
 project/MimaExcludes.scala                      |  2 ++
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  1 +
 .../execution/UnsafeKVExternalSorterSuite.scala |  1 +
 .../execution/UnsafeRowSerializerSuite.scala    |  2 +-
 12 files changed, 51 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0fd777e..d7d67db 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -65,7 +65,7 @@ object TaskContext {
   * An empty task context that does not represent an actual task.  This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
+    new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
   }
 }
 
@@ -146,6 +146,13 @@ abstract class TaskContext extends Serializable {
   def stageId(): Int
 
   /**
+   * How many times the stage that this task belongs to has been attempted. The first stage attempt
+   * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt
+   * numbers.
+   */
+  def stageAttemptNumber(): Int
+
+  /**
    * The ID of the RDD partition that is computed by this task.
    */
   def partitionId(): Int
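
Side note (not part of the diff): stageAttemptNumber() complements the existing attemptNumber() getter; the former counts retries of the enclosing stage, the latter counts retries of an individual task within a stage. A sketch that surfaces both, assuming a running SparkContext named sc:

    val labels = sc.parallelize(1 to 2, 2).mapPartitions { _ =>
      val tc = org.apache.spark.TaskContext.get()
      // stageAttemptNumber(): how often the whole stage has been attempted (added in this commit)
      // attemptNumber():      how often this particular task has been attempted
      Iterator(s"stageAttempt=${tc.stageAttemptNumber()} taskAttempt=${tc.attemptNumber()}")
    }.collect()
    labels.foreach(println)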

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c904e08..8159f1b 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -29,8 +29,9 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.util._
 
 private[spark] class TaskContextImpl(
-    val stageId: Int,
-    val partitionId: Int,
+    override val stageId: Int,
+    override val stageAttemptNumber: Int,
+    override val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 112b08f..cabe6a7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -78,6 +78,7 @@ private[spark] abstract class Task[T](
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     context = new TaskContextImpl(
       stageId,
+      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
       partitionId,
       taskAttemptId,
       attemptNumber,

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
index 94f5805..f8e233a 100644
--- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
+++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
@@ -38,6 +38,7 @@ public class JavaTaskContextCompileCheck {
     tc.attemptNumber();
     tc.partitionId();
     tc.stageId();
+    tc.stageAttemptNumber();
     tc.taskAttemptId();
   }
 
@@ -51,6 +52,7 @@ public class JavaTaskContextCompileCheck {
       context.isCompleted();
       context.isInterrupted();
       context.stageId();
+      context.stageAttemptNumber();
       context.partitionId();
       context.addTaskCompletionListener(this);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index a854f5bb..f3f891a 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -336,14 +336,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data1 = (1 to 10).map { x => x -> x}
 
     // second attempt -- also successful.  We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
     val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data2 = (11 to 20).map { x => x -> x}
 
     // interleave writes of both attempts -- we want to test that both attempts can occur
@@ -371,7 +371,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
-      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 362cd86..dcf89e4 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -29,6 +29,7 @@ object MemoryTestingUtils {
     val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
     new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = 0,
       attemptNumber = 0,

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9eda79a..0b6a45a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -143,6 +144,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }
 
+  test("TaskContext.stageAttemptNumber getter") {
+    sc = new SparkContext("local[1,2]", "test")
+
+    // Check stageAttemptNumbers are 0 for initial stage
+    val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
+      Seq(TaskContext.get().stageAttemptNumber()).iterator
+    }.collect()
+    assert(stageAttemptNumbers.toSet === Set(0))
+
+    // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException
+    val stageAttemptNumbersWithFailedStage =
+      sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
+      val stageAttemptNumber = TaskContext.get().stageAttemptNumber()
+      if (stageAttemptNumber < 2) {
+        // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
+        // will only trigger task resubmission in the same stage.
+        throw new FetchFailedException(null, 0, 0, 0, "Fake")
+      }
+      Seq(stageAttemptNumber).iterator
+    }.collect()
+
+    assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
+  }
+
   test("accumulators are updated on exception failures") {
     // This means use 1 core and 4 max task failures
     sc = new SparkContext("local[1,4]", "test")
@@ -175,7 +200,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
@@ -198,7 +223,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 1b32580..285e453 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
     try {
       TaskContext.setTaskContext(
-        new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+        new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
       block
     } finally {
       TaskContext.unset()

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d872093..13c57df 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,8 @@ object MimaExcludes {
   // Exclude rules for 2.1.x
   lazy val v21excludes = v20excludes ++ {
     Seq(
+      // [SPARK-22897] Expose stageAttemptId in TaskContext
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"),
       // [SPARK-19652][UI] Do auth checks for REST API access.
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"),
       ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"),

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 6cf18de..6c222a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -69,6 +69,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
       TaskContext.setTaskContext(new TaskContextImpl(
         stageId = 0,
+        stageAttemptNumber = 0,
         partitionId = 0,
         taskAttemptId = Random.nextInt(10000),
         attemptNumber = 0,

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 3d869c7..5ad7cad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = 98456,
       attemptNumber = 0,

http://git-wip-us.apache.org/repos/asf/spark/blob/e5ccac21/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 53105e0..c3ecf52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
         (i, converter(Row(i)))
       }
       val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null)
+      val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
 
       val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
         taskContext,

