Repository: spark
Updated Branches:
  refs/heads/master 7e24249af -> 1cb377007


[SPARK-4879] Use driver to coordinate Hadoop output committing for speculative 
tasks

Previously, SparkHadoopWriter always committed its tasks without question. The
problem is that when speculation is enabled, this can sometimes result in
multiple tasks committing their output to the same file. Even though an
HDFS-writing task may be re-launched due to speculation, the original task is
not killed and may eventually commit as well.

This can cause race conditions in which multiple committing tasks interfere
with each other, with the result that some partition files are lost entirely.
For more context on these kinds of scenarios, see SPARK-4879.

In Hadoop MapReduce jobs, the application master is a central coordinator that
decides whether any given task may commit. Before a task commits its output, it
asks the application master whether the commit is safe, and the application
master keeps track of which tasks have requested commits. Duplicate attempts
that would write to files already written by other tasks are denied permission
to commit.

This patch emulates that functionality: the crucial missing component was a
central arbiter, which is now provided by a new module, the
OutputCommitCoordinator. The coordinator lives on the driver; executors obtain
a reference to its actor and request permission to commit. As tasks are
reported by the DAGScheduler as completed (successfully or unsuccessfully), the
coordinator is also informed of these completion events so that it can update
its internal state.
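
Roughly, the executor-side flow looks like the following (a condensed sketch of
the SparkHadoopWriter.commit() change included in the diff below; cmtr and
taCtxt are the Hadoop OutputCommitter and TaskAttemptContext that the writer
already holds, and the speculation/escape-hatch check is omitted here for
brevity):

    if (cmtr.needsTaskCommit(taCtxt)) {
      val coordinator = SparkEnv.get.outputCommitCoordinator
      if (coordinator.canCommit(jobID, splitID, attemptID)) {
        // This attempt holds the exclusive commit lock for the partition
        cmtr.commitTask(taCtxt)
      } else {
        // Another attempt was authorized first; abort so the driver can
        // reschedule a new attempt if necessary
        cmtr.abortTask(taCtxt)
        throw new CommitDeniedException(
          s"$taID: commit denied by driver", jobID, splitID, attemptID)
      }
    } else {
      // Some other attempt already committed this partition's output
    }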

Future work includes more rigorous unit testing and additional optimizations
should this patch cause a performance regression; it is unclear what the
overall cost of contacting the driver for every Hadoop-committing task will be.
It is also important for those hitting this issue to backport this fix onto
previous versions of Spark, because the bug has serious consequences: data
loss.

Currently, the OutputCommitCoordinator is used only when `spark.speculation` is
true. It can be disabled by setting
`spark.hadoop.outputCommitCoordination.enabled=false` in SparkConf.
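
For example, a job that keeps speculation on but opts out of the new commit
coordination could be configured as follows (a minimal sketch; only the two
property names come from this patch, the rest is ordinary SparkConf usage and
the application name is made up):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("speculative-write-job")  // hypothetical app name
      .set("spark.speculation", "true")
      // Escape hatch added by this patch: bypass driver-side commit coordination
      .set("spark.hadoop.outputCommitCoordination.enabled", "false")
    val sc = new SparkContext(conf)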

This patch is an updated version of #4155 (by mccheah), which in turn was an 
updated version of this PR.

Closes #4155.

Author: mcheah <mch...@palantir.com>
Author: Josh Rosen <joshro...@databricks.com>

Closes #4066 from JoshRosen/SPARK-4879-sparkhadoopwriter-fix and squashes the 
following commits:

658116b [Josh Rosen] Merge remote-tracking branch 'origin/master' into 
SPARK-4879-sparkhadoopwriter-fix
ed783b2 [Josh Rosen] Address Andrew’s feedback.
e7be65a [Josh Rosen] Merge remote-tracking branch 'origin/master' into 
SPARK-4879-sparkhadoopwriter-fix
14861ea [Josh Rosen] splitID -> partitionID in a few places
ed8b554 [Josh Rosen] Merge remote-tracking branch 'origin/master' into 
SPARK-4879-sparkhadoopwriter-fix
48d5c1c [Josh Rosen] Roll back copiesRunning change in TaskSetManager
3969f5f [Josh Rosen] Re-enable guarding of commit coordination with 
spark.speculation setting.
ede7590 [Josh Rosen] Add test to ensure that a job that denies all commits 
cannot complete successfully.
97da5fe [Josh Rosen] Use actor only for RPC; call methods directly in 
DAGScheduler.
f582574 [Josh Rosen] Some cleanup in OutputCommitCoordinatorSuite
a7c0e29 [Josh Rosen] Create fake TaskInfo using dummy fields instead of Mockito.
997b41b [Josh Rosen] Roll back unnecessary 
DAGSchedulerSingleThreadedProcessLoop refactoring:
459310a [Josh Rosen] Roll back TaskSetManager changes that broke other tests.
dd00b7c [Josh Rosen] Move CommitDeniedException to executors package; remove 
`@DeveloperAPI` annotation.
c79df98 [Josh Rosen] Some misc. code style + doc changes:
f7d69c5 [Josh Rosen] Merge remote-tracking branch 'origin/master' into 
SPARK-4879-sparkhadoopwriter-fix
92e6dc9 [Josh Rosen] Bug fix: use task ID instead of StageID to index into 
authorizedCommitters.
b344bad [Josh Rosen] (Temporarily) re-enable “always coordinate” for 
testing purposes.
0aec91e [Josh Rosen] Only coordinate when speculation is enabled; add 
configuration option to bypass new coordination.
594e41a [mcheah] Fixing a scalastyle error
60a47f4 [mcheah] Writing proper unit test for OutputCommitCoordinator and 
fixing bugs.
d63f63f [mcheah] Fixing compiler error
9fe6495 [mcheah] Fixing scalastyle
1df2a91 [mcheah] Throwing exception if SparkHadoopWriter commit denied
d431144 [mcheah] Using more concurrency to process OutputCommitCoordinator 
requests.
c334255 [mcheah] Properly handling messages that could be sent after actor 
shutdown.
8d5a091 [mcheah] Was mistakenly serializing the accumulator in test suite.
9c6a4fa [mcheah] More OutputCommitCoordinator cleanup on stop()
78eb1b5 [mcheah] Better OutputCommitCoordinatorActor stopping; simpler canCommit
83de900 [mcheah] Making the OutputCommitCoordinatorMessage serializable
abc7db4 [mcheah] TaskInfo can't be null in DAGSchedulerSuite
f135a8e [mcheah] Moving the output commit coordinator from class into method.
1c2b219 [mcheah] Renaming outdated names for test function classes
66a71cd [mcheah] Removing whitespace modifications
6b543ba [mcheah] Removing redundant accumulator in unit test
c9decc6 [mcheah] Scalastyle fixes
bc80770 [mcheah] Unit tests for OutputCommitCoordinator
6e6f748 [mcheah] [SPARK-4879] Use the Spark driver to authorize Hadoop commits.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1cb37700
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1cb37700
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1cb37700

Branch: refs/heads/master
Commit: 1cb37700753437045b15c457b983532cd5a27fa5
Parents: 7e24249
Author: mcheah <mch...@palantir.com>
Authored: Tue Feb 10 20:12:18 2015 -0800
Committer: Andrew Or <and...@databricks.com>
Committed: Tue Feb 10 20:12:18 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/SparkContext.scala   |  11 +-
 .../main/scala/org/apache/spark/SparkEnv.scala  |  22 +-
 .../org/apache/spark/SparkHadoopWriter.scala    |  43 +++-
 .../scala/org/apache/spark/TaskEndReason.scala  |  14 ++
 .../spark/executor/CommitDeniedException.scala  |  35 +++
 .../org/apache/spark/executor/Executor.scala    |   5 +
 .../apache/spark/scheduler/DAGScheduler.scala   |  15 +-
 .../scheduler/OutputCommitCoordinator.scala     | 172 +++++++++++++++
 .../spark/scheduler/TaskSchedulerImpl.scala     |   9 +-
 .../apache/spark/scheduler/TaskSetManager.scala |   8 +-
 .../spark/scheduler/DAGSchedulerSuite.scala     |  25 ++-
 .../OutputCommitCoordinatorSuite.scala          | 213 +++++++++++++++++++
 12 files changed, 549 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala 
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 53fce6b..24a316e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -249,7 +249,16 @@ class SparkContext(config: SparkConf) extends Logging with 
ExecutorAllocationCli
   conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
 
   // Create the Spark execution environment (cache, map output tracker, etc)
-  private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+
+  // This function allows components created by SparkEnv to be mocked in unit tests:
+  private[spark] def createSparkEnv(
+      conf: SparkConf,
+      isLocal: Boolean,
+      listenerBus: LiveListenerBus): SparkEnv = {
+    SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+  }
+
+  private[spark] val env = createSparkEnv(conf, isLocal, listenerBus)
   SparkEnv.set(env)
 
   // Used to store a URL for each static file/jar together with the file's 
local timestamp

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b63bea5..2a0c7e7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.network.nio.NioBlockTransferService
-import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
@@ -67,6 +68,7 @@ class SparkEnv (
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
     val shuffleMemoryManager: ShuffleMemoryManager,
+    val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
   private[spark] var isStopped = false
@@ -88,6 +90,7 @@ class SparkEnv (
     blockManager.stop()
     blockManager.master.stop()
     metricsSystem.stop()
+    outputCommitCoordinator.stop()
     actorSystem.shutdown()
     // Unfortunately Akka's awaitTermination doesn't actually wait for the 
Netty server to shut
     // down, but let's call it anyway in case it gets fixed in a later release
@@ -169,7 +172,8 @@ object SparkEnv extends Logging {
   private[spark] def createDriverEnv(
       conf: SparkConf,
       isLocal: Boolean,
-      listenerBus: LiveListenerBus): SparkEnv = {
+      listenerBus: LiveListenerBus,
+      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
     assert(conf.contains("spark.driver.host"), "spark.driver.host is not set 
on the driver!")
     assert(conf.contains("spark.driver.port"), "spark.driver.port is not set 
on the driver!")
     val hostname = conf.get("spark.driver.host")
@@ -181,7 +185,8 @@ object SparkEnv extends Logging {
       port,
       isDriver = true,
       isLocal = isLocal,
-      listenerBus = listenerBus
+      listenerBus = listenerBus,
+      mockOutputCommitCoordinator = mockOutputCommitCoordinator
     )
   }
 
@@ -220,7 +225,8 @@ object SparkEnv extends Logging {
       isDriver: Boolean,
       isLocal: Boolean,
       listenerBus: LiveListenerBus = null,
-      numUsableCores: Int = 0): SparkEnv = {
+      numUsableCores: Int = 0,
+      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
 
     // Listener bus is only used on the driver
     if (isDriver) {
@@ -368,6 +374,13 @@ object SparkEnv extends Logging {
         "levels using the RDD.persist() method instead.")
     }
 
+    val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
+      new OutputCommitCoordinator(conf)
+    }
+    val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
+      new OutputCommitCoordinatorActor(outputCommitCoordinator))
+    outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+
     val envInstance = new SparkEnv(
       executorId,
       actorSystem,
@@ -384,6 +397,7 @@ object SparkEnv extends Logging {
       sparkFilesDir,
       metricsSystem,
       shuffleMemoryManager,
+      outputCommitCoordinator,
       conf)
       
     // Add a reference to tmp dir created by driver, we will delete this tmp 
dir when stop() is

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala 
b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 4023759..6eb4537 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.executor.CommitDeniedException
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.rdd.HadoopRDD
 
@@ -105,24 +106,56 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
   def commit() {
     val taCtxt = getTaskContext()
     val cmtr = getOutputCommitter()
-    if (cmtr.needsTaskCommit(taCtxt)) {
+
+    // Called after we have decided to commit
+    def performCommit(): Unit = {
       try {
         cmtr.commitTask(taCtxt)
-        logInfo (taID + ": Committed")
+        logInfo (s"$taID: Committed")
       } catch {
-        case e: IOException => {
+        case e: IOException =>
           logError("Error committing the output of task: " + taID.value, e)
           cmtr.abortTask(taCtxt)
           throw e
+      }
+    }
+
+    // First, check whether the task's output has already been committed by some other attempt
+    if (cmtr.needsTaskCommit(taCtxt)) {
+      // The task output needs to be committed, but we don't know whether some other task attempt
+      // might be racing to commit the same output partition. Therefore, coordinate with the driver
+      // in order to determine whether this attempt can commit (see SPARK-4879).
+      val shouldCoordinateWithDriver: Boolean = {
+        val sparkConf = SparkEnv.get.conf
+        // We only need to coordinate with the driver if there are multiple concurrent task
+        // attempts, which should only occur if speculation is enabled
+        val speculationEnabled = sparkConf.getBoolean("spark.speculation", false)
+        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
+        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+      }
+      if (shouldCoordinateWithDriver) {
+        val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
+        val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID)
+        if (canCommit) {
+          performCommit()
+        } else {
+          val msg = s"$taID: Not committed because the driver did not authorize commit"
+          logInfo(msg)
+          // We need to abort the task so that the driver can reschedule new attempts, if necessary
+          cmtr.abortTask(taCtxt)
+          throw new CommitDeniedException(msg, jobID, splitID, attemptID)
         }
+      } else {
+        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
+        performCommit()
       }
     } else {
-      logInfo ("No need to commit output of task: " + taID.value)
+      // Some other attempt committed the output, so we do nothing and signal success
+      logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}")
     }
   }
 
   def commitJob() {
-    // always ? Or if cmtr.needsTaskCommit ?
     val cmtr = getOutputCommitter()
     cmtr.commitJob(getJobContext())
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/TaskEndReason.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala 
b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index af5fd8e..29a5cd5 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -148,6 +148,20 @@ case object TaskKilled extends TaskFailedReason {
 
 /**
  * :: DeveloperApi ::
+ * Task requested the driver to commit, but was denied.
+ */
+@DeveloperApi
+case class TaskCommitDenied(
+    jobID: Int,
+    partitionID: Int,
+    attemptID: Int)
+  extends TaskFailedReason {
+  override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
+    s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
+}
+
+/**
+ * :: DeveloperApi ::
  * The task failed because the executor that it was running on was lost. This 
may happen because
  * the task crashed the JVM.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala 
b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
new file mode 100644
index 0000000..f7604a3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{TaskCommitDenied, TaskEndReason}
+
+/**
+ * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
+ */
+class CommitDeniedException(
+    msg: String,
+    jobID: Int,
+    splitID: Int,
+    attemptID: Int)
+  extends Exception(msg) {
+
+  def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
+
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 6b22dcd..b684fb7 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -253,6 +253,11 @@ private[spark] class Executor(
           execBackend.statusUpdate(taskId, TaskState.KILLED, 
ser.serialize(TaskKilled))
         }
 
+        case cDE: CommitDeniedException => {
+          val reason = cDE.toTaskEndReason
+          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+        }
+
         case t: Throwable => {
           // Attempt to exit cleanly by informing the driver of our failure.
           // If anything goes wrong (or this was a fatal exception), we will 
delegate to

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 1cfe986..7903557 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, 
ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
+import org.apache.spark.util._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 
 /**
@@ -63,7 +63,7 @@ class DAGScheduler(
     mapOutputTracker: MapOutputTrackerMaster,
     blockManagerMaster: BlockManagerMaster,
     env: SparkEnv,
-    clock: Clock = SystemClock)
+    clock: org.apache.spark.util.Clock = SystemClock)
   extends Logging {
 
   def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
@@ -126,6 +126,8 @@ class DAGScheduler(
   private[scheduler] val eventProcessLoop = new 
DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
+  private val outputCommitCoordinator = env.outputCommitCoordinator
+
   // Called by TaskScheduler to report task's starting.
   def taskStarted(task: Task[_], taskInfo: TaskInfo) {
     eventProcessLoop.post(BeginEvent(task, taskInfo))
@@ -808,6 +810,7 @@ class DAGScheduler(
     // will be posted, which should always come after a corresponding 
SparkListenerStageSubmitted
     // event.
     stage.latestInfo = StageInfo.fromStage(stage, 
Some(partitionsToCompute.size))
+    outputCommitCoordinator.stageStart(stage.id)
     listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
 
     // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it 
multiple times.
@@ -865,6 +868,7 @@ class DAGScheduler(
     } else {
       // Because we posted SparkListenerStageSubmitted earlier, we should post
       // SparkListenerStageCompleted here in case there are no tasks to run.
+      outputCommitCoordinator.stageEnd(stage.id)
       listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -909,6 +913,9 @@ class DAGScheduler(
     val stageId = task.stageId
     val taskType = Utils.getFormattedClassName(task)
 
+    outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
+      event.taskInfo.attempt, event.reason)
+
     // The success case is dealt with separately below, since we need to 
compute accumulator
     // updates before posting.
     if (event.reason != Success) {
@@ -921,6 +928,7 @@ class DAGScheduler(
       // Skip all the actions if the stage has been cancelled.
       return
     }
+
     val stage = stageIdToStage(task.stageId)
 
     def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) 
= {
@@ -1073,6 +1081,9 @@ class DAGScheduler(
           handleExecutorLost(bmAddress.executorId, fetchFailed = true, 
Some(task.epoch))
         }
 
+      case commitDenied: TaskCommitDenied =>
+        // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
+
       case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
         // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala 
b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
new file mode 100644
index 0000000..759df02
--- /dev/null
+++ 
b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark._
+import org.apache.spark.util.{AkkaUtils, ActorLogReceive}
+
+private sealed trait OutputCommitCoordinationMessage extends Serializable
+
+private case object StopCoordinator extends OutputCommitCoordinationMessage
+private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long)
+
+/**
+ * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
+ * policy.
+ *
+ * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is
+ * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit
+ * output will be forwarded to the driver's OutputCommitCoordinator.
+ *
+ * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests)
+ * for an extensive design discussion.
+ */
+private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
+
+  // Initialized by SparkEnv
+  var coordinatorActor: Option[ActorRef] = None
+  private val timeout = AkkaUtils.askTimeout(conf)
+  private val maxAttempts = AkkaUtils.numRetries(conf)
+  private val retryInterval = AkkaUtils.retryWaitMs(conf)
+
+  private type StageId = Int
+  private type PartitionId = Long
+  private type TaskAttemptId = Long
+
+  /**
+   * Map from active stages's id => partition id => task attempt with exclusive lock on committing
+   * output for that partition.
+   *
+   * Entries are added to the top-level map when stages start and are removed they finish
+   * (either successfully or unsuccessfully).
+   *
+   * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
+   */
+  private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
+  private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]
+
+  /**
+   * Called by tasks to ask whether they can commit their output to HDFS.
+   *
+   * If a task attempt has been authorized to commit, then all other attempts to commit the same
+   * task will be denied.  If the authorized task attempt fails (e.g. due to its executor being
+   * lost), then a subsequent task attempt may be authorized to commit its output.
+   *
+   * @param stage the stage number
+   * @param partition the partition number
+   * @param attempt a unique identifier for this task attempt
+   * @return true if this task is authorized to commit, false otherwise
+   */
+  def canCommit(
+      stage: StageId,
+      partition: PartitionId,
+      attempt: TaskAttemptId): Boolean = {
+    val msg = AskPermissionToCommitOutput(stage, partition, attempt)
+    coordinatorActor match {
+      case Some(actor) =>
+        AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout)
+      case None =>
+        logError(
+          "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?")
+        false
+    }
+  }
+
+  // Called by DAGScheduler
+  private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
+    authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]()
+  }
+
+  // Called by DAGScheduler
+  private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+    authorizedCommittersByStage.remove(stage)
+  }
+
+  // Called by DAGScheduler
+  private[scheduler] def taskCompleted(
+      stage: StageId,
+      partition: PartitionId,
+      attempt: TaskAttemptId,
+      reason: TaskEndReason): Unit = synchronized {
+    val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
+      logDebug(s"Ignoring task completion for completed stage")
+      return
+    })
+    reason match {
+      case Success =>
+      // The task output has been committed successfully
+      case denied: TaskCommitDenied =>
+        logInfo(
+          s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt")
+      case otherReason =>
+        logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" +
+          s" clearing lock")
+        authorizedCommitters.remove(partition)
+    }
+  }
+
+  def stop(): Unit = synchronized {
+    coordinatorActor.foreach(_ ! StopCoordinator)
+    coordinatorActor = None
+    authorizedCommittersByStage.clear()
+  }
+
+  // Marked private[scheduler] instead of private so this can be mocked in tests
+  private[scheduler] def handleAskPermissionToCommit(
+      stage: StageId,
+      partition: PartitionId,
+      attempt: TaskAttemptId): Boolean = synchronized {
+    authorizedCommittersByStage.get(stage) match {
+      case Some(authorizedCommitters) =>
+        authorizedCommitters.get(partition) match {
+          case Some(existingCommitter) =>
+            logDebug(s"Denying $attempt to commit for stage=$stage, 
partition=$partition; " +
+              s"existingCommitter = $existingCommitter")
+            false
+          case None =>
+            logDebug(s"Authorizing $attempt to commit for stage=$stage, 
partition=$partition")
+            authorizedCommitters(partition) = attempt
+            true
+        }
+      case None =>
+        logDebug(s"Stage $stage has completed, so not allowing task attempt 
$attempt to commit")
+        false
+    }
+  }
+}
+
+private[spark] object OutputCommitCoordinator {
+
+  // This actor is used only for RPC
+  class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator)
+    extends Actor with ActorLogReceive with Logging {
+
+    override def receiveWithLogging = {
+      case AskPermissionToCommitOutput(stage, partition, taskAttempt) =>
+        sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)
+      case StopCoordinator =>
+        logInfo("OutputCommitCoordinator stopped!")
+        context.stop(self)
+        sender ! true
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 79f84e7..54f8fcf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl(
     val tasks = taskSet.tasks
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " 
tasks")
     this.synchronized {
-      val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
+      val manager = createTaskSetManager(taskSet, maxTaskFailures)
       activeTaskSets(taskSet.id) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
@@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl(
     backend.reviveOffers()
   }
 
+  // Label as private[scheduler] to allow tests to swap in different task set managers if necessary
+  private[scheduler] def createTaskSetManager(
+      taskSet: TaskSet,
+      maxTaskFailures: Int): TaskSetManager = {
+    new TaskSetManager(this, taskSet, maxTaskFailures)
+  }
+
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = 
synchronized {
     logInfo("Cancelling stage " + stageId)
     activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 55024ec..99a5f71 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -292,7 +292,8 @@ private[spark] class TaskSetManager(
    * an attempt running on this host, in case the host is slow. In addition, 
the task should meet
    * the given locality constraint.
    */
-  private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+  // Labeled as protected to allow tests to override providing speculative tasks if necessary
+  protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
     : Option[(Int, TaskLocality.Value)] =
   {
     speculatableTasks.retain(index => !successful(index)) // Remove finished 
tasks from set
@@ -708,7 +709,10 @@ private[spark] class TaskSetManager(
       put(info.executorId, clock.getTime())
     sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, 
taskMetrics)
     addPendingTask(index)
-    if (!isZombie && state != TaskState.KILLED) {
+    if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) {
+      // If a task failed because its attempt to commit was denied, do not count this failure
+      // towards failing the stage. This is intended to prevent spurious stage failures in cases
+      // where many speculative tasks are launched and denied to commit.
       assert (null != failureReason)
       numFailures(index) += 1
       if (numFailures(index) >= maxTaskFailures) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index eb11621..9d0c127 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -208,7 +208,7 @@ class DAGSchedulerSuite extends FunSuiteLike  with 
BeforeAndAfter with LocalSpar
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, 
null, null))
+        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, 
createFakeTaskInfo(), null))
       }
     }
   }
@@ -219,7 +219,7 @@ class DAGSchedulerSuite extends FunSuiteLike  with 
BeforeAndAfter with LocalSpar
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
         runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2,
-          Map[Long, Any]((accumId, 1)), null, null))
+          Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null))
       }
     }
   }
@@ -476,7 +476,7 @@ class DAGSchedulerSuite extends FunSuiteLike  with 
BeforeAndAfter with LocalSpar
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
       null,
       Map[Long, Any](),
-      null,
+      createFakeTaskInfo(),
       null))
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
     assert(sparkListener.failedStages.contains(1))
@@ -487,7 +487,7 @@ class DAGSchedulerSuite extends FunSuiteLike  with 
BeforeAndAfter with LocalSpar
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
       null,
       Map[Long, Any](),
-      null,
+      createFakeTaskInfo(),
       null))
     // The SparkListener should not receive redundant failure events.
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
@@ -507,14 +507,14 @@ class DAGSchedulerSuite extends FunSuiteLike  with 
BeforeAndAfter with LocalSpar
     assert(newEpoch > oldEpoch)
     val taskSet = taskSets(0)
     // should be ignored for being too old
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 
1), null, null, null))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 
1), null, createFakeTaskInfo(), null))
     // should work because it's a non-failed host
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 
1), null, null, null))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 
1), null, createFakeTaskInfo(), null))
     // should be ignored for being too old
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 
1), null, null, null))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 
1), null, createFakeTaskInfo(), null))
     // should work because it's a new epoch
     taskSet.tasks(1).epoch = newEpoch
-    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 
1), null, null, null))
+    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 
1), null, createFakeTaskInfo(), null))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
            Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
     complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -766,5 +766,14 @@ class DAGSchedulerSuite extends FunSuiteLike  with 
BeforeAndAfter with LocalSpar
     assert(scheduler.shuffleToMapStage.isEmpty)
     assert(scheduler.waitingStages.isEmpty)
   }
+
+  // Nothing in this test should break if the task info's fields are null, but
+  // OutputCommitCoordinator requires the task info itself to not be null.
+  private def createFakeTaskInfo(): TaskInfo = {
+    val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false)
+    info.finishTime = 1  // to prevent spurious errors in JobProgressListener
+    info
+  }
+
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1cb37700/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
 
b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
new file mode 100644
index 0000000..3cc860c
--- /dev/null
+++ 
b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.File
+import java.util.concurrent.TimeoutException
+
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, 
OutputCommitter}
+
+import org.apache.spark._
+import org.apache.spark.rdd.{RDD, FakeOutputCommitter}
+import org.apache.spark.util.Utils
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+/**
+ * Unit tests for the output commit coordination functionality.
+ *
+ * The unit test makes both the original task and the speculated task
+ * attempt to commit, where committing is emulated by creating a
+ * directory. If both tasks create directories then the end result is
+ * a failure.
+ *
+ * Note that there are some aspects of this test that are less than ideal.
+ * In particular, the test mocks the speculation-dequeuing logic to always
+ * dequeue a task and consider it as speculated. Immediately after initially
+ * submitting the tasks and calling reviveOffers(), reviveOffers() is invoked
+ * again to pick up the speculated task. This may be hacking the original
+ * behavior in too much of an unrealistic fashion.
+ *
+ * Also, the validation is done by checking the number of files in a directory.
+ * Ideally, an accumulator would be used for this, where we could increment
+ * the accumulator in the output committer's commitTask() call. If the call to
+ * commitTask() was called twice erroneously then the test would ideally fail 
because
+ * the accumulator would be incremented twice.
+ *
+ * The problem with this test implementation is that when both a speculated 
task and
+ * its original counterpart complete, only one of the accumulator's increments 
is
+ * captured. This results in a paradox where if the OutputCommitCoordinator 
logic
+ * was not in SparkHadoopWriter, the tests would still pass because only one 
of the
+ * increments would be captured even though the commit in both tasks was 
executed
+ * erroneously.
+ */
+class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
+
+  var outputCommitCoordinator: OutputCommitCoordinator = null
+  var tempDir: File = null
+  var sc: SparkContext = null
+
+  before {
+    tempDir = Utils.createTempDir()
+    val conf = new SparkConf()
+      .setMaster("local[4]")
+      .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName)
+      .set("spark.speculation", "true")
+    sc = new SparkContext(conf) {
+      override private[spark] def createSparkEnv(
+          conf: SparkConf,
+          isLocal: Boolean,
+          listenerBus: LiveListenerBus): SparkEnv = {
+        outputCommitCoordinator = spy(new OutputCommitCoordinator(conf))
+        // Use Mockito.spy() to maintain the default infrastructure everywhere 
else.
+        // This mocking allows us to control the coordinator responses in test 
cases.
+        SparkEnv.createDriverEnv(conf, isLocal, listenerBus, 
Some(outputCommitCoordinator))
+      }
+    }
+    // Use Mockito.spy() to maintain the default infrastructure everywhere else
+    val mockTaskScheduler = 
spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
+
+    doAnswer(new Answer[Unit]() {
+      override def answer(invoke: InvocationOnMock): Unit = {
+        // Submit the tasks, then force the task scheduler to dequeue the
+        // speculated task
+        invoke.callRealMethod()
+        mockTaskScheduler.backend.reviveOffers()
+      }
+    }).when(mockTaskScheduler).submitTasks(Matchers.any())
+
+    doAnswer(new Answer[TaskSetManager]() {
+      override def answer(invoke: InvocationOnMock): TaskSetManager = {
+        val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
+        new TaskSetManager(mockTaskScheduler, taskSet, 4) {
+          var hasDequeuedSpeculatedTask = false
+          override def dequeueSpeculativeTask(
+              execId: String,
+              host: String,
+              locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] 
= {
+            if (!hasDequeuedSpeculatedTask) {
+              hasDequeuedSpeculatedTask = true
+              Some(0, TaskLocality.PROCESS_LOCAL)
+            } else {
+              None
+            }
+          }
+        }
+      }
+    }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), 
Matchers.any())
+
+    sc.taskScheduler = mockTaskScheduler
+    val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, 
mockTaskScheduler)
+    sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler)
+    sc.dagScheduler = dagSchedulerWithMockTaskScheduler
+  }
+
+  after {
+    sc.stop()
+    tempDir.delete()
+    outputCommitCoordinator = null
+  }
+
+  test("Only one of two duplicate commit tasks should commit") {
+    val rdd = sc.parallelize(Seq(1), 1)
+    sc.runJob(rdd, 
OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _,
+      0 until rdd.partitions.size, allowLocal = false)
+    assert(tempDir.list().size === 1)
+  }
+
+  test("If commit fails, if task is retried it should not be locked, and will 
succeed.") {
+    val rdd = sc.parallelize(Seq(1), 1)
+    sc.runJob(rdd, 
OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _,
+      0 until rdd.partitions.size, allowLocal = false)
+    assert(tempDir.list().size === 1)
+  }
+
+  test("Job should not complete if all commits are denied") {
+    // Create a mock OutputCommitCoordinator that denies all attempts to commit
+    doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
+      Matchers.any(), Matchers.any(), Matchers.any())
+    val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
+    def resultHandler(x: Int, y: Unit): Unit = {}
+    val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, 
Unit](rdd,
+      OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully,
+      0 until rdd.partitions.size, resultHandler, 0)
+    // It's an error if the job completes successfully even though no 
committer was authorized,
+    // so throw an exception if the job was allowed to complete.
+    intercept[TimeoutException] {
+      Await.result(futureAction, 5 seconds)
+    }
+    assert(tempDir.list().size === 0)
+  }
+}
+
+/**
+ * Class with methods that can be passed to runJob to test commits with a mock 
committer.
+ */
+private case class OutputCommitFunctions(tempDirPath: String) {
+
+  // Mock output committer that simulates a successful commit (after commit is 
authorized)
+  private def successfulOutputCommitter = new FakeOutputCommitter {
+    override def commitTask(context: TaskAttemptContext): Unit = {
+      Utils.createDirectory(tempDirPath)
+    }
+  }
+
+  // Mock output committer that simulates a failed commit (after commit is 
authorized)
+  private def failingOutputCommitter = new FakeOutputCommitter {
+    override def commitTask(taskAttemptContext: TaskAttemptContext) {
+      throw new RuntimeException
+    }
+  }
+
+  def commitSuccessfully(iter: Iterator[Int]): Unit = {
+    val ctx = TaskContext.get()
+    runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter)
+  }
+
+  def failFirstCommitAttempt(iter: Iterator[Int]): Unit = {
+    val ctx = TaskContext.get()
+    runCommitWithProvidedCommitter(ctx, iter,
+      if (ctx.attemptNumber == 0) failingOutputCommitter else 
successfulOutputCommitter)
+  }
+
+  private def runCommitWithProvidedCommitter(
+      ctx: TaskContext,
+      iter: Iterator[Int],
+      outputCommitter: OutputCommitter): Unit = {
+    def jobConf = new JobConf {
+      override def getOutputCommitter(): OutputCommitter = outputCommitter
+    }
+    val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
+      override def newTaskAttemptContext(
+        conf: JobConf,
+        attemptId: TaskAttemptID): TaskAttemptContext = {
+        mock(classOf[TaskAttemptContext])
+      }
+    }
+    sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber)
+    sparkHadoopWriter.commit()
+  }
+}

