Repository: spark
Updated Branches:
  refs/heads/master ea5ae2705 -> ad960885b


[SPARK-8029] Robust shuffle writer

Currently, all the shuffle writers write to the target path directly, so the
file could be corrupted by another attempt of the same partition on the same
executor. They should write to a temporary file and then rename it to the
target path, as we do in the output committer. In order to make the rename
atomic, the temporary file should be created in the same local directory
(FileSystem) as the target.

This PR is based on #9214, thanks to squito. Closes #9214

Author: Davies Liu <dav...@databricks.com>

Closes #9610 from davies/safe_shuffle.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ad960885
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ad960885
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ad960885

Branch: refs/heads/master
Commit: ad960885bfee7850c18eb5338546cecf2b2e9876
Parents: ea5ae27
Author: Davies Liu <dav...@databricks.com>
Authored: Thu Nov 12 22:44:57 2015 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu Nov 12 22:44:57 2015 -0800

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      |   9 +-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |  13 ++-
 .../shuffle/FileShuffleBlockResolver.scala      |  17 +--
 .../shuffle/IndexShuffleBlockResolver.scala     | 102 +++++++++++++++--
 .../spark/shuffle/hash/HashShuffleWriter.scala  |  25 ++++
 .../spark/shuffle/sort/SortShuffleWriter.scala  |  11 +-
 .../org/apache/spark/storage/BlockManager.scala |   9 +-
 .../spark/storage/DiskBlockObjectWriter.scala   |   5 +-
 .../scala/org/apache/spark/util/Utils.scala     |  12 +-
 .../spark/util/collection/ExternalSorter.scala  |   1 -
 .../sort/IndexShuffleBlockResolverSuite.scala   | 114 +++++++++++++++++++
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  |   9 +-
 .../map/AbstractBytesToBytesMapSuite.java       |   3 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |   3 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   | 107 ++++++++++++++++-
 .../BypassMergeSortShuffleWriterSuite.scala     |  14 ++-
 16 files changed, 402 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
 
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index ee82d67..a1a1fb0 100644
--- 
a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -125,7 +125,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
       partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, 
partitionLengths, null);
       mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), 
partitionLengths);
       return;
     }
@@ -155,9 +155,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
       writer.commitAndClose();
     }
 
-    partitionLengths =
-      writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    File tmp = Utils.tempFileWith(output);
+    partitionLengths = writePartitionedFile(tmp);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, 
partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), 
partitionLengths);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 6a0a89e..744c300 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -206,8 +206,10 @@ public class UnsafeShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
     final long[] partitionLengths;
+    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final File tmp = Utils.tempFileWith(output);
     try {
-      partitionLengths = mergeSpills(spills);
+      partitionLengths = mergeSpills(spills, tmp);
     } finally {
       for (SpillInfo spill : spills) {
         if (spill.file.exists() && ! spill.file.delete()) {
@@ -215,7 +217,7 @@ public class UnsafeShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
         }
       }
     }
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, 
partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), 
partitionLengths);
   }
 
@@ -248,8 +250,7 @@ public class UnsafeShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
-    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws 
IOException {
     final boolean compressionEnabled = 
sparkConf.getBoolean("spark.shuffle.compress", true);
     final CompressionCodec compressionCodec = 
CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala 
b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index cd253a7..39fadd8 100644
--- 
a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ 
b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, 
TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, 
TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -84,17 +84,8 @@ private[spark] class FileShuffleBlockResolver(conf: 
SparkConf)
         Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
           val blockFile = blockManager.diskBlockManager.getFile(blockId)
-          // Because of previous failures, the shuffle file may already exist 
on this machine.
-          // If so, remove it.
-          if (blockFile.exists) {
-            if (blockFile.delete()) {
-              logInfo(s"Removed existing shuffle file $blockFile")
-            } else {
-              logWarning(s"Failed to remove existing shuffle file $blockFile")
-            }
-          }
-          blockManager.getDiskWriter(blockId, blockFile, serializerInstance, 
bufferSize,
-            writeMetrics)
+          val tmp = Utils.tempFileWith(blockFile)
+          blockManager.getDiskWriter(blockId, tmp, serializerInstance, 
bufferSize, writeMetrics)
         }
       }
       // Creating the file to write to and creating a disk writer both involve 
interacting with

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5e4c2b5..05b1eed 100644
--- 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -21,13 +21,12 @@ import java.io._
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{SparkEnv, Logging, SparkConf}
 
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and 
physical file location.
@@ -40,10 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
  */
 // Note: Changes to the format in this file should be kept in sync with
 // 
org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends 
ShuffleBlockResolver
+private[spark] class IndexShuffleBlockResolver(
+    conf: SparkConf,
+    _blockManager: BlockManager = null)
+  extends ShuffleBlockResolver
   with Logging {
 
-  private lazy val blockManager = SparkEnv.get.blockManager
+  private lazy val blockManager = 
Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
 
   private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
@@ -75,13 +77,68 @@ private[spark] class IndexShuffleBlockResolver(conf: 
SparkConf) extends ShuffleB
   }
 
   /**
+   * Check whether the given index and data files match each other.
+   * If so, return the partition lengths in the data file. Otherwise return 
null.
+   */
+  private def checkIndexAndDataFile(index: File, data: File, blocks: Int): 
Array[Long] = {
+    // the index file should have `block + 1` longs as offset.
+    if (index.length() != (blocks + 1) * 8) {
+      return null
+    }
+    val lengths = new Array[Long](blocks)
+    // Read the lengths of blocks
+    val in = try {
+      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+    } catch {
+      case e: IOException =>
+        return null
+    }
+    try {
+      // Convert the offsets into lengths of each block
+      var offset = in.readLong()
+      if (offset != 0L) {
+        return null
+      }
+      var i = 0
+      while (i < blocks) {
+        val off = in.readLong()
+        lengths(i) = off - offset
+        offset = off
+        i += 1
+      }
+    } catch {
+      case e: IOException =>
+        return null
+    } finally {
+      in.close()
+    }
+
+    // the size of data file should match with index file
+    if (data.length() == lengths.sum) {
+      lengths
+    } else {
+      null
+    }
+  }
+
+  /**
    * Write an index file with the offsets of each block, plus a final offset 
at the end for the
    * end of the output file. This will be used by getBlockData to figure out 
where each block
    * begins and ends.
+   *
+   * It will commit the data and index file as an atomic operation, use the 
existing ones, or
+   * replace them with new ones.
+   *
+   * Note: the `lengths` will be updated to match the existing index file if 
use the existing ones.
    * */
-  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = 
{
+  def writeIndexFileAndCommit(
+      shuffleId: Int,
+      mapId: Int,
+      lengths: Array[Long],
+      dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
-    val out = new DataOutputStream(new BufferedOutputStream(new 
FileOutputStream(indexFile)))
+    val indexTmp = Utils.tempFileWith(indexFile)
+    val out = new DataOutputStream(new BufferedOutputStream(new 
FileOutputStream(indexTmp)))
     Utils.tryWithSafeFinally {
       // We take in lengths of each block, need to convert it to offsets.
       var offset = 0L
@@ -93,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: 
SparkConf) extends ShuffleB
     } {
       out.close()
     }
+
+    val dataFile = getDataFile(shuffleId, mapId)
+    // There is only one IndexShuffleBlockResolver per executor, this 
synchronization make sure
+    // the following check and rename are atomic.
+    synchronized {
+      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, 
lengths.length)
+      if (existingLengths != null) {
+        // Another attempt for the same task has already written our map 
outputs successfully,
+        // so just use the existing partition lengths and delete our temporary 
map outputs.
+        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+        if (dataTmp != null && dataTmp.exists()) {
+          dataTmp.delete()
+        }
+        indexTmp.delete()
+      } else {
+        // This is the first successful attempt in writing the map outputs for 
this task,
+        // so override any existing index and data files with the ones we 
wrote.
+        if (indexFile.exists()) {
+          indexFile.delete()
+        }
+        if (dataFile.exists()) {
+          dataFile.delete()
+        }
+        if (!indexTmp.renameTo(indexFile)) {
+          throw new IOException("fail to rename file " + indexTmp + " to " + 
indexFile)
+        }
+        if (dataTmp != null && dataTmp.exists() && 
!dataTmp.renameTo(dataFile)) {
+          throw new IOException("fail to rename file " + dataTmp + " to " + 
dataFile)
+        }
+      }
+    }
   }
 
   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala 
b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 41df70c..412bf70 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.hash
 
+import java.io.IOException
+
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V](
       writer.commitAndClose()
       writer.fileSegment().length
     }
+    // rename all shuffle files to final paths
+    // Note: there is only one ShuffleBlockResolver in executor
+    shuffleBlockResolver.synchronized {
+      shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+        val output = blockManager.diskBlockManager.getFile(writer.blockId)
+        if (sizes(i) > 0) {
+          if (output.exists()) {
+            // Use length of existing file and delete our own temporary one
+            sizes(i) = output.length()
+            writer.file.delete()
+          } else {
+            // Commit by renaming our temporary file to something the fetcher 
expects
+            if (!writer.file.renameTo(output)) {
+              throw new IOException(s"fail to rename ${writer.file} to 
$output")
+            }
+          }
+        } else {
+          if (output.exists()) {
+            output.delete()
+          }
+        }
+      }
+    }
     MapStatus(blockManager.shuffleServerId, sizes)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 808317b..f83cf88 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, 
BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, 
ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
@@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the 
shuffle write time,
     // because it just opens a single file, so is typically too fast to 
measure accurately
     // (see SPARK-3570).
-    val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val tmp = Utils.tempFileWith(output)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, 
IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
-    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
+    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, 
partitionLengths, tmp)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c374b93..661c706 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,10 +21,10 @@ import java.io._
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{ExecutionContext, Await, Future}
 import scala.concurrent.duration._
-import scala.util.control.NonFatal
+import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.util.Random
+import scala.util.control.NonFatal
 
 import sun.nio.ch.DirectBuffer
 
@@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{Serializer, SerializerInstance}
 import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.util._
 
 private[spark] sealed trait BlockValues
@@ -660,7 +659,7 @@ private[spark] class BlockManager(
     val compressStream: OutputStream => OutputStream = 
wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(file, serializerInstance, bufferSize, 
compressStream,
-      syncWrites, writeMetrics)
+      syncWrites, writeMetrics, blockId)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 80d426f..e2dd80f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -34,14 +34,15 @@ import org.apache.spark.util.Utils
  * reopened again.
  */
 private[spark] class DiskBlockObjectWriter(
-    file: File,
+    val file: File,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active 
DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics)
+    writeMetrics: ShuffleWriteMetrics,
+    val blockId: BlockId = null)
   extends OutputStream
   with Logging {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 316c194..1b3acb8 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,8 +21,8 @@ import java.io._
 import java.lang.management.ManagementFactory
 import java.net._
 import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
 import javax.net.ssl.HttpsURLConnection
 
 import scala.collection.JavaConverters._
@@ -30,7 +30,7 @@ import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
 import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
 import org.apache.log4j.PropertyConfigurator
 import org.eclipse.jetty.util.MultiException
 import org.json4s._
-
 import tachyon.TachyonURI
 import tachyon.client.{TachyonFS, TachyonFile}
 
@@ -2169,6 +2168,13 @@ private[spark] object Utils extends Logging {
     val resource = createResource
     try f.apply(resource) finally resource.close()
   }
+
+  /**
+   * Returns a path of temporary file which is in the same directory with 
`path`.
+   */
+  def tempFileWith(path: File): File = {
+    new File(path.getAbsolutePath + "." + UUID.randomUUID())
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala 
b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index bd6844d..2440139 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -638,7 +638,6 @@ private[spark] class ExternalSorter[K, V, C](
    * called by the SortShuffleWriter.
    *
    * @param blockId block ID to write to. The index file will be blockId.name 
+ ".index".
-   * @param context a TaskContext for a running Spark task, for us to update 
shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used 
by map output tracker)
    */
   def writePartitionedFile(

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
 
b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
new file mode 100644
index 0000000..0b19861
--- /dev/null
+++ 
b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.{Mock, MockitoAnnotations}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+
+class IndexShuffleBlockResolverSuite extends SparkFunSuite with 
BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = 
_
+  @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: 
DiskBlockManager = _
+
+  private var tempDir: File = _
+  private val conf: SparkConf = new SparkConf(loadDefaults = false)
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+    MockitoAnnotations.initMocks(this)
+
+    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+    when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+      new Answer[File] {
+        override def answer(invocation: InvocationOnMock): File = {
+          new File(tempDir, invocation.getArguments.head.toString)
+        }
+      })
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+  }
+
+  test("commit shuffle files multiple times") {
+    val lengths = Array[Long](10, 0, 20)
+    val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val dataTmp = File.createTempFile("shuffle", null, tempDir)
+    val out = new FileOutputStream(dataTmp)
+    out.write(new Array[Byte](30))
+    out.close()
+    resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp)
+
+    val dataFile = resolver.getDataFile(1, 2)
+    assert(dataFile.exists())
+    assert(dataFile.length() === 30)
+    assert(!dataTmp.exists())
+
+    val dataTmp2 = File.createTempFile("shuffle", null, tempDir)
+    val out2 = new FileOutputStream(dataTmp2)
+    val lengths2 = new Array[Long](3)
+    out2.write(Array[Byte](1))
+    out2.write(new Array[Byte](29))
+    out2.close()
+    resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2)
+    assert(lengths2.toSeq === lengths.toSeq)
+    assert(dataFile.exists())
+    assert(dataFile.length() === 30)
+    assert(!dataTmp2.exists())
+
+    // The dataFile should be the previous one
+    val in = new FileInputStream(dataFile)
+    val firstByte = new Array[Byte](1)
+    in.read(firstByte)
+    assert(firstByte(0) === 0)
+
+    // remove data file
+    dataFile.delete()
+
+    val dataTmp3 = File.createTempFile("shuffle", null, tempDir)
+    val out3 = new FileOutputStream(dataTmp3)
+    val lengths3 = Array[Long](10, 10, 15)
+    out3.write(Array[Byte](2))
+    out3.write(new Array[Byte](34))
+    out3.close()
+    resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3)
+    assert(lengths3.toSeq != lengths.toSeq)
+    assert(dataFile.exists())
+    assert(dataFile.length() === 35)
+    assert(!dataTmp2.exists())
+
+    // The dataFile should be the previous one
+    val in2 = new FileInputStream(dataFile)
+    val firstByte2 = new Array[Byte](1)
+    in2.read(firstByte2)
+    assert(firstByte2(0) === 2)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 0e0eca5..bc85918 100644
--- 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeShuffleWriterSuite {
           (Integer) args[3],
           new CompressStream(),
           false,
-          (ShuffleWriteMetrics) args[4]
+          (ShuffleWriteMetrics) args[4],
+          (BlockId) args[0]
         );
       }
     });
@@ -169,9 +170,13 @@ public class UnsafeShuffleWriterSuite {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
         partitionSizesInMergedFile = (long[]) 
invocationOnMock.getArguments()[2];
+        File tmp = (File) invocationOnMock.getArguments()[3];
+        mergedOutputFile.delete();
+        tmp.renameTo(mergedOutputFile);
         return null;
       }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), 
any(long[].class));
+    }).when(shuffleBlockResolver)
+      .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), 
any(File.class));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
       new Answer<Tuple2<TempShuffleBlockId, File>>() {

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 3bca790..d87a1d2 100644
--- 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -117,7 +117,8 @@ public abstract class AbstractBytesToBytesMapSuite {
           (Integer) args[3],
           new CompressStream(),
           false,
-          (ShuffleWriteMetrics) args[4]
+          (ShuffleWriteMetrics) args[4],
+          (BlockId) args[0]
         );
       }
     });

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 11c3a7b..a1c9f6f 100644
--- 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeExternalSorterSuite {
           (Integer) args[3],
           new CompressStream(),
           false,
-          (ShuffleWriteMetrics) args[4]
+          (ShuffleWriteMetrics) args[4],
+          (BlockId) args[0]
         );
       }
     });

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala 
b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 4a0877d..0de10ae 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,12 +17,16 @@
 
 package org.apache.spark
 
+import java.util.concurrent.{Callable, Executors, ExecutorService, 
CyclicBarrier}
+
 import org.scalatest.Matchers
 
 import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, 
ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, 
SparkListenerTaskEnd}
 import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.ShuffleWriter
 import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
 import org.apache.spark.util.MutablePair
 
@@ -317,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with 
Matchers with LocalSparkC
     assert(metrics.bytesWritten === metrics.byresRead)
     assert(metrics.bytesWritten > 0)
   }
+
+  test("multiple simultaneous attempts for one task (SPARK-8029)") {
+    sc = new SparkContext("local", "test", conf)
+    val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    val manager = sc.env.shuffleManager
+
+    val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L)
+    val metricsSystem = sc.env.metricsSystem
+    val shuffleMapRdd = new MyRDD(sc, 1, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
+    val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+
+    // first attempt -- it's successful
+    val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
+      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val data1 = (1 to 10).map { x => x -> x}
+
+    // second attempt -- also successful.  We'll write out different data,
+    // just to simulate the fact that the records may get written differently
+    // depending on what gets spilled, what gets combined, etc.
+    val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
+      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val data2 = (11 to 20).map { x => x -> x}
+
+    // interleave writes of both attempts -- we want to test that both attempts can occur
+    // simultaneously, and everything is still OK
+    // (writeAndClose feeds one attempt's records to its writer and returns stop(true)'s result)
+    def writeAndClose(
+      writer: ShuffleWriter[Int, Int])(
+      iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+      writer.write(iter)
+      writer.stop(true)
+    }
+    val interleaver = new InterleaveIterators(
+      data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+    val (mapOutput1, mapOutput2) = interleaver.run()
+
+    // both attempts must report a status, and the two must agree on location and block size
+    assert(mapOutput1.isDefined)
+    assert(mapOutput2.isDefined)
+    assert(mapOutput1.get.location === mapOutput2.get.location)
+    assert(mapOutput1.get.getSizeForBlock(0) === mapOutput2.get.getSizeForBlock(0))
+
+    // register one of the map outputs -- doesn't matter which one
+    mapOutput1.foreach { case mapStatus =>
+      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+    }
+
+    val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
+      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val readData = reader.read().toIndexedSeq
+    assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
+
+    manager.unregisterShuffle(0)
+  }
+}
+
+/**
+ * Test helper that runs f1 over data1 and f2 over data2 on two threads in lock step. Each call
+ * to next() on either wrapped iterator first waits on a shared two-party barrier, so neither
+ * function receives its n-th element until the other has requested its n-th as well. Both inputs
+ * must have the same size (enforced by require), so every barrier wait can find a partner.
+ */
+class InterleaveIterators[T, R](
+  data1: Seq[T],
+  f1: Iterator[T] => R,
+  data2: Seq[T],
+  f2: Iterator[T] => R) {
+
+  require(data1.size == data2.size)
+  val barrier = new CyclicBarrier(2)
+  class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] {
+    def hasNext: Boolean = sub.hasNext
+    def next: E = {
+      barrier.await()
+      sub.next()
+    }
+  }
+
+  val c1 = new Callable[R] {
+    override def call(): R = f1(new BarrierIterator(1, data1.iterator))
+  }
+  val c2 = new Callable[R] {
+    override def call(): R = f2(new BarrierIterator(2, data2.iterator))
+  }
+  val e: ExecutorService = Executors.newFixedThreadPool(2)
+
+  def run(): (R, R) = {
+    val future1 = e.submit(c1)
+    val future2 = e.submit(c2)
+    // Shut the pool down even when a task fails: get() then throws, and without the interrupt
+    // sent by shutdownNow() the surviving worker could stay parked on the barrier forever.
+    try {
+      (future1.get(), future2.get())
+    } finally {
+      e.shutdownNow()
+    }
+  }
 }
 
 object ShuffleSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/ad960885/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
 
b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index b92a302..d3b1b2b 100644
--- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,6 +68,17 @@ class BypassMergeSortShuffleWriterSuite extends 
SparkFunSuite with BeforeAndAfte
     when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
     when(taskContext.taskMetrics()).thenReturn(taskMetrics)
     when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
+    doAnswer(new Answer[Void] {  // stub for blockResolver.writeIndexFileAndCommit (see .when below)
+      def answer(invocationOnMock: InvocationOnMock): Void = {
+        val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]  // 4th arg: the File
+        if (tmp != null) {  // nothing to move when no temp file was passed
+          outputFile.delete  // remove any earlier output before the rename
+          tmp.renameTo(outputFile)  // move the temp file onto the path getDataFile(0, 0) returns
+        }
+        null
+      }
+    }).when(blockResolver)
+      .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), 
any(classOf[File]))
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(blockManager.getDiskWriter(
       any[BlockId],
@@ -84,7 +95,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite 
with BeforeAndAfte
           args(3).asInstanceOf[Int],
           compressStream = identity,
           syncWrites = false,
-          args(4).asInstanceOf[ShuffleWriteMetrics]
+          args(4).asInstanceOf[ShuffleWriteMetrics],
+          blockId = args(0).asInstanceOf[BlockId]
         )
       }
     })


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to