Repository: spark
Updated Branches:
  refs/heads/master e0c090f22 -> a6fc300e9


[SPARK-22897][CORE] Expose stageAttemptId in TaskContext

## What changes were proposed in this pull request?
The stage attempt ID is added to TaskContext, exposed as `stageAttemptNumber()`, and the corresponding `TaskContextImpl` construction sites are updated.

## How was this patch tested?
Added a new test in TaskContextSuite; two cases are tested:
1. The normal case, without failures
2. The exception case, with resubmitted stages

Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897)
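
For illustration only (not part of this patch), a minimal sketch of how a job could read the new getter, assuming a running `SparkContext` named `sc`:

```scala
import org.apache.spark.TaskContext

// Collect (stageId, stageAttemptNumber, taskAttemptNumber) from each partition.
// stageAttemptNumber() is 0 on a stage's first attempt and increases each time
// the stage is resubmitted (e.g. after a FetchFailedException).
val attempts = sc.parallelize(1 to 4, 2).mapPartitions { _ =>
  val ctx = TaskContext.get()
  Iterator((ctx.stageId(), ctx.stageAttemptNumber(), ctx.attemptNumber()))
}.collect()
```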

Author: Xianjin YE <advance...@gmail.com>

Closes #20082 from advancedxy/SPARK-22897.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6fc300e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6fc300e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6fc300e

Branch: refs/heads/master
Commit: a6fc300e91273230e7134ac6db95ccb4436c6f8f
Parents: e0c090f
Author: Xianjin YE <advance...@gmail.com>
Authored: Tue Jan 2 23:30:38 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Jan 2 23:30:38 2018 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskContext.scala    |  9 +++++-
 .../org/apache/spark/TaskContextImpl.scala      |  5 ++--
 .../scala/org/apache/spark/scheduler/Task.scala |  1 +
 .../spark/JavaTaskContextCompileCheck.java      |  2 ++
 .../scala/org/apache/spark/ShuffleSuite.scala   |  6 ++--
 .../spark/memory/MemoryTestingUtils.scala       |  1 +
 .../spark/scheduler/TaskContextSuite.scala      | 29 ++++++++++++++++++--
 .../spark/storage/BlockInfoManagerSuite.scala   |  2 +-
 project/MimaExcludes.scala                      |  3 ++
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  1 +
 .../execution/UnsafeKVExternalSorterSuite.scala |  1 +
 .../execution/UnsafeRowSerializerSuite.scala    |  2 +-
 .../SortBasedAggregationStoreSuite.scala        |  3 +-
 13 files changed, 54 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0b87cd5..6973974 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -66,7 +66,7 @@ object TaskContext {
   * An empty task context that does not represent an actual task.  This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
+    new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
   }
 }
 
@@ -151,6 +151,13 @@ abstract class TaskContext extends Serializable {
   def stageId(): Int
 
   /**
+   * How many times the stage that this task belongs to has been attempted. The first stage attempt
+   * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt
+   * numbers.
+   */
+  def stageAttemptNumber(): Int
+
+  /**
    * The ID of the RDD partition that is computed by this task.
    */
   def partitionId(): Int

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 01d8973..cccd3ea 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -41,8 +41,9 @@ import org.apache.spark.util._
  * `TaskMetrics` & `MetricsSystem` objects are not thread safe.
  */
 private[spark] class TaskContextImpl(
-    val stageId: Int,
-    val partitionId: Int,
+    override val stageId: Int,
+    override val stageAttemptNumber: Int,
+    override val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7767ef1..f536fc2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -79,6 +79,7 @@ private[spark] abstract class Task[T](
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     context = new TaskContextImpl(
       stageId,
+      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
       partitionId,
       taskAttemptId,
       attemptNumber,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
index 94f5805..f8e233a 100644
--- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
+++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
@@ -38,6 +38,7 @@ public class JavaTaskContextCompileCheck {
     tc.attemptNumber();
     tc.partitionId();
     tc.stageId();
+    tc.stageAttemptNumber();
     tc.taskAttemptId();
   }
 
@@ -51,6 +52,7 @@ public class JavaTaskContextCompileCheck {
       context.isCompleted();
       context.isInterrupted();
       context.stageId();
+      context.stageAttemptNumber();
       context.partitionId();
       context.addTaskCompletionListener(this);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 3931d53..ced5a06 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data1 = (1 to 10).map { x => x -> x}
 
     // second attempt -- also successful.  We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
     val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data2 = (11 to 20).map { x => x -> x}
 
     // interleave writes of both attempts -- we want to test that both attempts can occur
@@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
-      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 362cd86..dcf89e4 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -29,6 +29,7 @@ object MemoryTestingUtils {
     val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
     new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = 0,
       attemptNumber = 0,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index a1d9085..aa9c36c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }
 
+  test("TaskContext.stageAttemptNumber getter") {
+    sc = new SparkContext("local[1,2]", "test")
+
+    // Check stageAttemptNumbers are 0 for initial stage
+    val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
+      Seq(TaskContext.get().stageAttemptNumber()).iterator
+    }.collect()
+    assert(stageAttemptNumbers.toSet === Set(0))
+
+    // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException
+    val stageAttemptNumbersWithFailedStage =
+      sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
+      val stageAttemptNumber = TaskContext.get().stageAttemptNumber()
+      if (stageAttemptNumber < 2) {
+        // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
+        // will only trigger task resubmission in the same stage.
+        throw new FetchFailedException(null, 0, 0, 0, "Fake")
+      }
+      Seq(stageAttemptNumber).iterator
+    }.collect()
+
+    assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
+  }
+
   test("accumulators are updated on exception failures") {
     // This means use 1 core and 4 max task failures
     sc = new SparkContext("local[1,4]", "test")
@@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
@@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.registered
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 917db76..9c0699b 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
     try {
       TaskContext.setTaskContext(
-        new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+        new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
       block
     } finally {
       TaskContext.unset()

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 81584af..3b452f3 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {
 
   // Exclude rules for 2.3.x
   lazy val v23excludes = v22excludes ++ Seq(
+    // [SPARK-22897] Expose stageAttemptId in TaskContext
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"),
+
     // SPARK-22789: Map-only continuous processing execution
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 232c1be..3e31d22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
       TaskContext.setTaskContext(new TaskContextImpl(
         stageId = 0,
+        stageAttemptNumber = 0,
         partitionId = 0,
         taskAttemptId = Random.nextInt(10000),
         attemptNumber = 0,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 604502f..6af9f8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = 98456,
       attemptNumber = 0,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index dff88ce..a3ae938 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
         (i, converter(Row(i)))
       }
       val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null)
+      val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
 
       val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
         taskContext,

http://git-wip-us.apache.org/repos/asf/spark/blob/a6fc300e/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
index 10f1ee2..3fad7df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
@@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite  extends SparkFunSuite with LocalSparkConte
     val conf = new SparkConf()
     sc = new SparkContext("local[2, 4]", "test", conf)
     val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0)
-    TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null))
+    TaskContext.setTaskContext(
+      new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null))
   }
 
   override def afterAll(): Unit = TaskContext.unset()

