Repository: spark
Updated Branches:
  refs/heads/branch-2.2 124789b62 -> d7c3aae20


[SPARK-23207][SPARK-22905][SPARK-24564][SPARK-25114][SQL][BACKPORT-2.2] Shuffle+Repartition on a DataFrame could lead to incorrect answers

## What changes were proposed in this pull request?

Back port of #20393.

Currently, shuffle repartition uses RoundRobinPartitioning, and the generated result 
is nondeterministic because the order of the input rows is not deterministic.

The bug can be triggered when a repartition call follows a shuffle (which leads 
to nondeterministic row ordering), as in the following pattern, where `->` 
indicates a shuffle:

upstream stage -> repartition stage -> result stage

When one of the executor processes goes down, some tasks of the repartition 
stage are retried and produce a different row ordering, so the result-stage 
tasks that are retried read different data from the ones that had already 
finished, and some rows can be lost.
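
For illustration, the following minimal Java sketch (a toy model, not Spark code; `RoundRobinDemo` and its parameters are hypothetical) deals rows out to target partitions in turn, the way round-robin partitioning does. Which partition a row lands in depends only on its position in the input, so the same rows arriving in a different order on a retry end up in different partitions. The `start` parameter stands in for the starting position each map task picks.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Toy model of round-robin partitioning (hypothetical; not Spark's implementation). */
final class RoundRobinDemo {

  /** Deals rows to numPartitions buckets in turn; the first row goes to (start + 1) % numPartitions. */
  static List<List<Long>> roundRobin(List<Long> rows, int numPartitions, int start) {
    List<List<Long>> buckets = new ArrayList<>();
    for (int p = 0; p < numPartitions; p++) {
      buckets.add(new ArrayList<>());
    }
    int position = start;
    for (Long row : rows) {
      position = (position + 1) % numPartitions; // advance, then assign
      buckets.get(position).add(row);
    }
    return buckets;
  }

  public static void main(String[] args) {
    // The same four rows, read in two different orders (first attempt vs. retry).
    List<Long> firstAttempt = Arrays.asList(1L, 2L, 3L, 4L);
    List<Long> retry = Arrays.asList(2L, 1L, 4L, 3L);
    System.out.println(roundRobin(firstAttempt, 2, 0)); // [[2, 4], [1, 3]]
    System.out.println(roundRobin(retry, 2, 0));        // [[1, 3], [2, 4]]
  }
}
```

If the downstream task for partition 0 already finished using the first attempt's output (`[2, 4]`) and the task for partition 1 is rerun against the retry's output (also `[2, 4]`), rows 1 and 3 are never read while rows 2 and 4 are read twice.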

The following code returns 931532 instead of 1000000:
```
import scala.sys.process._

import org.apache.spark.TaskContext
val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()
```

In this PR, we propose the most straightforward fix for this problem: perform a 
local sort before partitioning. Once the input row ordering is deterministic, 
the function from rows to partitions is fully deterministic too.
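
The fix can be illustrated with the same kind of toy model: sorting each input partition locally before assigning rows round-robin makes the assignment depend on the set of rows rather than on their arrival order. This is a hypothetical Java sketch (`SortedRoundRobinDemo` is not a Spark class) that sorts by natural `Long` order; the patch itself sorts `UnsafeRow`s by a hash-code prefix with a byte-wise tie-break.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Toy model of "sort locally, then round-robin" (hypothetical; not Spark's implementation). */
final class SortedRoundRobinDemo {

  static List<List<Long>> sortThenRoundRobin(List<Long> rows, int numPartitions, int start) {
    // Local sort: any total order works, as long as every attempt uses the same one.
    List<Long> sorted = new ArrayList<>(rows);
    sorted.sort(null); // null comparator = natural ordering
    List<List<Long>> buckets = new ArrayList<>();
    for (int p = 0; p < numPartitions; p++) {
      buckets.add(new ArrayList<>());
    }
    int position = start;
    for (Long row : sorted) {
      position = (position + 1) % numPartitions; // advance, then assign
      buckets.get(position).add(row);
    }
    return buckets;
  }

  public static void main(String[] args) {
    // Both input orders now yield [[2, 4], [1, 3]].
    System.out.println(sortThenRoundRobin(Arrays.asList(1L, 2L, 3L, 4L), 2, 0));
    System.out.println(sortThenRoundRobin(Arrays.asList(2L, 1L, 4L, 3L), 2, 0));
  }
}
```

The price is that a whole input partition must be buffered and sorted before its first row can be assigned, which is where the performance cost of this approach comes from.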

The downside of this approach is that the extra local sort reduces the 
performance of repartition(), so we add a new config, 
`spark.sql.execution.sortBeforeRepartition`, to control whether the sort is 
applied. It is enabled by default to be safe by default, but users may choose 
to turn it off manually to avoid the performance regression.

This patch also changes the output row ordering of repartition(), which made a 
number of test cases fail because they compared the results directly.

Adds a unit test to ExchangeSuite.

With this patch (and `spark.sql.execution.sortBeforeRepartition` set to true), 
the following query returns 1000000:
```
import scala.sys.process._

import org.apache.spark.TaskContext

spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()

res7: Long = 1000000
```

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

## How was this patch tested?

Ran all SBT unit tests for org.apache.spark.sql.*.

Ran pyspark tests for module pyspark-sql.

Closes #22079 from bersprockets/SPARK-23207.

Lead-authored-by: Xingbo Jiang <xingbo.ji...@databricks.com>
Co-authored-by: Bruce Robbins <bersprock...@gmail.com>
Co-authored-by: Zheng RuiFeng <ruife...@foxmail.com>
Signed-off-by: Xiao Li <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7c3aae2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7c3aae2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7c3aae2

Branch: refs/heads/branch-2.2
Commit: d7c3aae2074b3dd3923dd754c0a3c97308c66893
Parents: 124789b
Author: Xingbo Jiang <xingbo.ji...@databricks.com>
Authored: Thu Aug 23 14:22:56 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Thu Aug 23 14:22:56 2018 -0700

----------------------------------------------------------------------
 .../unsafe/sort/RecordComparator.java           |   4 +-
 .../unsafe/sort/UnsafeInMemorySorter.java       |   7 +-
 .../unsafe/sort/UnsafeSorterSpillMerger.java    |   4 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |   2 +
 .../apache/spark/memory/TestMemoryConsumer.java |  10 +
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |   4 +-
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  |   8 +-
 .../mllib/clustering/GaussianMixtureModel.scala |   2 +-
 .../spark/mllib/feature/ChiSqSelector.scala     |   2 +-
 .../apache/spark/ml/feature/Word2VecSuite.scala |   3 +-
 .../sql/execution/RecordBinaryComparator.java   |  74 +++++
 .../sql/execution/UnsafeExternalRowSorter.java  |  46 ++-
 .../org/apache/spark/sql/internal/SQLConf.scala |  14 +
 .../sql/execution/UnsafeKVExternalSorter.java   |  10 +-
 .../apache/spark/sql/execution/SortExec.scala   |   2 +-
 .../execution/exchange/ShuffleExchange.scala    |  55 +++-
 .../sort/RecordBinaryComparatorSuite.java       | 322 +++++++++++++++++++
 .../spark/sql/execution/ExchangeSuite.scala     |  26 +-
 .../datasources/parquet/ParquetIOSuite.scala    |   6 +-
 .../execution/streaming/ForeachSinkSuite.scala  |   4 +-
 20 files changed, 575 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
index 09e4258..02b5de8 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -32,6 +32,8 @@ public abstract class RecordComparator {
   public abstract int compare(
     Object leftBaseObject,
     long leftBaseOffset,
+    int leftBaseLength,
     Object rightBaseObject,
-    long rightBaseOffset);
+    long rightBaseOffset,
+    int rightBaseLength);
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 869ec90..839b41d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -61,12 +61,13 @@ public final class UnsafeInMemorySorter {
       int uaoSize = UnsafeAlignedOffset.getUaoSize();
       if (prefixComparisonResult == 0) {
         final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
-        // skip length
         final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
+        final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize);
         final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
-        // skip length
         final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
-        return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+        final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize);
+        return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2,
+          baseOffset2, baseLength2);
       } else {
         return prefixComparisonResult;
       }
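
Before this change the comparator received only each record's start offset (hence the `// skip length` comments); the hunk above now also reads the length word stored just before each record and forwards it, since a byte-wise comparator has to know where each record ends. A hypothetical Java model of that length-prefixed layout follows. It uses a `ByteBuffer` in place of Spark's memory pages and assumes a 4-byte length field (Spark's `UnsafeAlignedOffset` may use a wider field on platforms that require aligned access).

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

/**
 * Toy model of a page of length-prefixed records (hypothetical; not Spark's memory code).
 * Each record is stored as a 4-byte length followed by the record bytes.
 */
final class LengthPrefixedPage {
  // Stands in for UnsafeAlignedOffset.getUaoSize(); assumed to be 4 bytes here.
  static final int LENGTH_SIZE = 4;

  private final ByteBuffer page = ByteBuffer.allocate(1024);

  /** Appends a record and returns the offset of its first data byte (just past the length). */
  int append(byte[] record) {
    page.putInt(record.length);
    int dataOffset = page.position();
    page.put(record);
    return dataOffset;
  }

  /** Reads the length stored just before the data, like getSize(base, offset - uaoSize). */
  int lengthAt(int dataOffset) {
    return page.getInt(dataOffset - LENGTH_SIZE);
  }

  /** Copies out exactly lengthAt(dataOffset) bytes; without the length, the end is unknown. */
  byte[] recordAt(int dataOffset) {
    byte[] out = new byte[lengthAt(dataOffset)];
    for (int i = 0; i < out.length; i++) {
      out[i] = page.get(dataOffset + i);
    }
    return out;
  }

  public static void main(String[] args) {
    LengthPrefixedPage page = new LengthPrefixedPage();
    int first = page.append("abc".getBytes(StandardCharsets.UTF_8));
    int second = page.append("hello".getBytes(StandardCharsets.UTF_8));
    System.out.println(first + " " + page.lengthAt(first));   // 4 3
    System.out.println(second + " " + page.lengthAt(second)); // 11 5
  }
}
```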

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index cf4dfde..ff0dcc2 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger {
         prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
       if (prefixComparisonResult == 0) {
         return recordComparator.compare(
-          left.getBaseObject(), left.getBaseOffset(),
-          right.getBaseObject(), right.getBaseOffset());
+          left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
+          right.getBaseObject(), right.getBaseOffset(), right.getRecordLength());
       } else {
         return prefixComparisonResult;
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 63a87e7..102836d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -413,6 +413,8 @@ abstract class RDD[T: ClassTag](
    *
    * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
    * which can avoid performing a shuffle.
+   *
+   * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
    */
   def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope {
     coalesce(numPartitions, shuffle = true)

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
index db91329..0bbaea6 100644
--- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -17,6 +17,10 @@
 
 package org.apache.spark.memory;
 
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
 import java.io.IOException;
 
 public class TestMemoryConsumer extends MemoryConsumer {
@@ -43,6 +47,12 @@ public class TestMemoryConsumer extends MemoryConsumer {
     used -= size;
     taskMemoryManager.releaseExecutionMemory(size, this);
   }
+
+  @VisibleForTesting
+  public void freePage(MemoryBlock page) {
+    used -= page.size();
+    taskMemoryManager.freePage(page, this);
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 8d847da..cce01a3 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -71,8 +71,10 @@ public class UnsafeExternalSorterSuite {
     public int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset) {
+      long rightBaseOffset,
+      int rightBaseLength) {
       return 0;
     }
   };

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 1a3e11e..cfb0030 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -98,8 +98,10 @@ public class UnsafeInMemorySorterSuite {
       public int compare(
         Object leftBaseObject,
         long leftBaseOffset,
+        int leftBaseLength,
         Object rightBaseObject,
-        long rightBaseOffset) {
+        long rightBaseOffset,
+        int rightBaseLength) {
         return 0;
       }
     };
@@ -164,8 +166,10 @@ public class UnsafeInMemorySorterSuite {
       public int compare(
               Object leftBaseObject,
               long leftBaseOffset,
+              int leftBaseLength,
               Object rightBaseObject,
-              long rightBaseOffset) {
+              long rightBaseOffset,
+              int rightBaseLength) {
         return 0;
       }
     };

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index afbe4f9..1933d54 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -154,7 +154,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
       val dataArray = Array.tabulate(weights.length) { i =>
         Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
       }
-      spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
+      spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): GaussianMixtureModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 862be6f..015cc9f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -144,7 +144,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
       val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
         Data(model.selectedFeatures(i))
       }
-      spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
+      spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): ChiSqSelectorModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 6183606..10682ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val oldModel = new OldWord2VecModel(word2VecMap)
     val instance = new Word2VecModel("myWord2VecModel", oldModel)
     val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+    assert(newInstance.getVectors.collect().sortBy(_.getString(0)) ===
+      instance.getVectors.collect().sortBy(_.getString(0)))
   }
 
   test("Word2Vec works with input that is non-nullable (NGram)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
new file mode 100644
index 0000000..40c2cc8
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+
+public final class RecordBinaryComparator extends RecordComparator {
+
+  @Override
+  public int compare(
+      Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) {
+    int i = 0;
+
+    // If the arrays have different length, the longer one is larger.
+    if (leftLen != rightLen) {
+      return leftLen - rightLen;
+    }
+
+    // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since
+    // we have guaranteed `leftLen` == `rightLen`.
+
+    // check if stars align and we can get both offsets to be aligned
+    if ((leftOff % 8) == (rightOff % 8)) {
+      while ((leftOff + i) % 8 != 0 && i < leftLen) {
+        final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
+        final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
+        if (v1 != v2) {
+          return v1 > v2 ? 1 : -1;
+        }
+        i += 1;
+      }
+    }
+    // for architectures that support unaligned accesses, chew it up 8 bytes at a time
+    if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) {
+      while (i <= leftLen - 8) {
+        final long v1 = Platform.getLong(leftObj, leftOff + i);
+        final long v2 = Platform.getLong(rightObj, rightOff + i);
+        if (v1 != v2) {
+          return v1 > v2 ? 1 : -1;
+        }
+        i += 8;
+      }
+    }
+    // this will finish off the unaligned comparisons, or do the entire aligned comparison
+    // whichever is needed.
+    while (i < leftLen) {
+      final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
+      final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
+      if (v1 != v2) {
+        return v1 > v2 ? 1 : -1;
+      }
+      i += 1;
+    }
+
+    // The two arrays are equal.
+    return 0;
+  }
+}
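
The comparator above imposes a length-first, byte-wise order on records. For illustration only, here is a simplified Java sketch of that idea on plain `byte[]` arrays (`BinaryRecordOrder` is a hypothetical name). It compares one unsigned byte at a time and omits the 8-bytes-at-a-time fast path, so it should not be assumed to order records exactly like `RecordBinaryComparator`; it only shows that a fixed total order over the raw bytes is enough to make the local sort, and therefore the partition assignment, repeatable.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Simplified length-first, byte-wise order on byte arrays (hypothetical illustration). */
final class BinaryRecordOrder {

  static int compare(byte[] left, byte[] right) {
    // Records of different lengths are ordered by length alone (shorter first).
    if (left.length != right.length) {
      return Integer.compare(left.length, right.length);
    }
    // Same length: the first differing byte, read as unsigned 0..255, decides.
    for (int i = 0; i < left.length; i++) {
      int v1 = left[i] & 0xff;
      int v2 = right[i] & 0xff;
      if (v1 != v2) {
        return v1 > v2 ? 1 : -1;
      }
    }
    return 0; // byte-for-byte identical
  }

  public static void main(String[] args) {
    List<byte[]> records = new ArrayList<>(Arrays.asList(
        new byte[] {(byte) 0x80}, new byte[] {1, 2, 3}, new byte[] {0x7f}, new byte[] {1, 2}));
    records.sort(BinaryRecordOrder::compare);
    for (byte[] r : records) {
      System.out.println(Arrays.toString(r));
    }
    // Prints [127], [-128], [1, 2], [1, 2, 3]: length first, then unsigned bytes.
  }
}
```

The `& 0xff` masks matter because Java's `byte` is signed: without them, `(byte) 0x80` would compare as -128 and sort before `0x7f`.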

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index c29b002..54cec60 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution;
 
 import java.io.IOException;
+import java.util.function.Supplier;
 
 import scala.collection.Iterator;
 import scala.math.Ordering;
@@ -55,26 +56,50 @@ public final class UnsafeExternalRowSorter {
 
     public static class Prefix {
       /** Key prefix value, or the null prefix value if isNull = true. **/
-      long value;
+      public long value;
 
       /** Whether the key is null. */
-      boolean isNull;
+      public boolean isNull;
     }
 
     /**
      * Computes prefix for the given row. For efficiency, the returned object may be reused in
      * further calls to a given PrefixComputer.
      */
-    abstract Prefix computePrefix(InternalRow row);
+    public abstract Prefix computePrefix(InternalRow row);
   }
 
-  public UnsafeExternalRowSorter(
+  public static UnsafeExternalRowSorter createWithRecordComparator(
+      StructType schema,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  public static UnsafeExternalRowSorter create(
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       PrefixComputer prefixComputer,
       long pageSizeBytes,
       boolean canUseRadixSort) throws IOException {
+    Supplier<RecordComparator> recordComparatorSupplier =
+      () -> new RowComparator(ordering, schema.length());
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  private UnsafeExternalRowSorter(
+      StructType schema,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
     this.schema = schema;
     this.prefixComputer = prefixComputer;
     final SparkEnv sparkEnv = SparkEnv.get();
@@ -84,7 +109,7 @@ public final class UnsafeExternalRowSorter {
       sparkEnv.blockManager(),
       sparkEnv.serializerManager(),
       taskContext,
-      new RowComparator(ordering, schema.length()),
+      recordComparatorSupplier.get(),
       prefixComparator,
       sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
                              DEFAULT_INITIAL_SORT_BUFFER_SIZE),
@@ -207,8 +232,15 @@ public final class UnsafeExternalRowSorter {
     }
 
     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
-      // TODO: Why are the sizes -1?
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
+      // Note that since ordering doesn't need the total length of the record, we just pass -1
+      // into the row.
       row1.pointTo(baseObj1, baseOff1, -1);
       row2.pointTo(baseObj2, baseOff2, -1);
       return ordering.compare(row1, row2);

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ebabd1a..9db5acd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -829,6 +829,18 @@ object SQLConf {
     .regexConf
     .createWithDefault("(?i)url".r)
 
+  val SORT_BEFORE_REPARTITION =
+    buildConf("spark.sql.execution.sortBeforeRepartition")
+      .internal()
+      .doc("When perform a repartition following a shuffle, the output row ordering would be " +
+        "nondeterministic. If some downstream stages fail and some tasks of the repartition " +
+        "stage retry, these tasks may generate different data, and that can lead to correctness " +
+        "issues. Turn on this config to insert a local sort before actually doing repartition " +
+        "to generate consistent repartition results. The performance of repartition() may go " +
+        "down since we insert extra local sort before it.")
+      .booleanConf
+      .createWithDefault(true)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -961,6 +973,8 @@ class SQLConf extends Serializable with Logging {
 
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
+  def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 7d67b87..7549dec 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -252,8 +252,14 @@ public final class UnsafeKVExternalSorter {
     }
 
     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
-      // Note that since ordering doesn't need the total length of the record, we just pass -1
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
+      // Note that since ordering doesn't need the total length of the record, we just pass -1 
       // into the row.
       row1.pointTo(baseObj1, baseOff1 + 4, -1);
       row2.pointTo(baseObj2, baseOff2 + 4, -1);

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index f98ae82..d225979 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -84,7 +84,7 @@ case class SortExec(
     }
 
     val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-    val sorter = new UnsafeExternalRowSorter(
+    val sorter = UnsafeExternalRowSorter.create(
       schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)
 
     if (testSpillFrequency > 0) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index eebe6ad..c0ba513 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.exchange
 
 import java.util.Random
+import java.util.function.Supplier
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
@@ -30,7 +31,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.MutablePair
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 
 /**
  * Performs a shuffle that will result in the desired `newPartitioning`.
@@ -242,14 +246,61 @@ object ShuffleExchange {
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
+
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
-      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+      // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
+      // otherwise a retry task may output different rows and thus lead to data loss.
+      //
+      // Currently we following the most straight-forward way that perform a local sort before
+      // partitioning.
+      //
+      // Note that we don't perform local sort if the new partitioning has only 1 partition, under
+      // that case all output rows go to the same partition.
+      val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) &&
+          newPartitioning.numPartitions > 1 &&
+          newPartitioning.isInstanceOf[RoundRobinPartitioning]) {
         rdd.mapPartitionsInternal { iter =>
+          val recordComparatorSupplier = new Supplier[RecordComparator] {
+            override def get: RecordComparator = new RecordBinaryComparator()
+          }
+          // The comparator for comparing row hashcode, which should always be Integer.
+          val prefixComparator = PrefixComparators.LONG
+          val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)
+          // The prefix computer generates row hashcode as the prefix, so we may decrease the
+          // probability that the prefixes are equal when input rows choose column values from a
+          // limited range.
+          val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+            private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+            override def computePrefix(row: InternalRow):
+            UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+              // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null.
+              result.isNull = false
+              result.value = row.hashCode()
+              result
+            }
+          }
+          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+
+          val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
+            StructType.fromAttributes(outputAttributes),
+            recordComparatorSupplier,
+            prefixComparator,
+            prefixComputer,
+            pageSize,
+            canUseRadixSort)
+          sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+        }
+      } else {
+        rdd
+      }
+
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        newRdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
         }
       } else {
-        rdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           val mutablePair = new MutablePair[Int, InternalRow]()
           iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }

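In the hunk above, each input partition is sorted (row-hash prefix first, full binary comparison as the tie-breaker) before rows are dealt out round-robin, so the output partition a row lands in no longer depends on the order in which rows arrived. A minimal Java sketch of that idea follows. It assumes strings stand in for UnsafeRows, `String.hashCode()` stands in for the row-hash prefix, and the round-robin start position is fixed per input partition; `SortedRoundRobinSketch` and its methods are hypothetical and not part of Spark.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch: strings stand in for UnsafeRows; not Spark's implementation.
final class SortedRoundRobinSketch {

  // Sort by a cheap hash "prefix" first, then by full content to break ties,
  // loosely mirroring the prefix comparator + record comparator pair in the hunk.
  static List<String> sortDeterministically(List<String> rows) {
    Comparator<String> byHash = Comparator.comparingInt(String::hashCode);
    Comparator<String> order = byHash.thenComparing(Comparator.<String>naturalOrder());
    List<String> sorted = new ArrayList<>(rows);
    sorted.sort(order);
    return sorted;
  }

  // Deal rows into numPartitions buckets round-robin, beginning at startPartition.
  static List<List<String>> roundRobin(List<String> rows, int numPartitions, int startPartition) {
    List<List<String>> buckets = new ArrayList<>();
    for (int p = 0; p < numPartitions; p++) {
      buckets.add(new ArrayList<>());
    }
    int position = startPartition;
    for (String row : rows) {
      buckets.get(position).add(row);
      position = (position + 1) % numPartitions;
    }
    return buckets;
  }
}
```

In the sketch, strings that compare equal are identical, so reordering duplicates cannot change which values end up in which bucket; the duplicated-rows case in `ExchangeSuite` checks the same property on real Datasets.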
http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
new file mode 100644
index 0000000..97f3dc5
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.execution.sort;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryConsumer;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.execution.RecordBinaryComparator;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form.
+ */
+public class RecordBinaryComparatorSuite {
+
+  private final TaskMemoryManager memoryManager = new TaskMemoryManager(
+      new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+  private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
+
+  private final int uaoSize = UnsafeAlignedOffset.getUaoSize();
+
+  private MemoryBlock dataPage;
+  private long pageCursor;
+
+  private LongArray array;
+  private int pos;
+
+  @Before
+  public void beforeEach() {
+    // Only compare between two input rows.
+    array = consumer.allocateArray(2);
+    pos = 0;
+
+    dataPage = memoryManager.allocatePage(4096, consumer);
+    pageCursor = dataPage.getBaseOffset();
+  }
+
+  @After
+  public void afterEach() {
+    consumer.freePage(dataPage);
+    dataPage = null;
+    pageCursor = 0;
+
+    consumer.freeArray(array);
+    array = null;
+    pos = 0;
+  }
+
+  private void insertRow(UnsafeRow row) {
+    Object recordBase = row.getBaseObject();
+    long recordOffset = row.getBaseOffset();
+    int recordLength = row.getSizeInBytes();
+
+    Object baseObject = dataPage.getBaseObject();
+    assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size());
+    long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor);
+    UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
+    pageCursor += uaoSize;
+    Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength);
+    pageCursor += recordLength;
+
+    assert(pos < 2);
+    array.set(pos, recordAddress);
+    pos++;
+  }
+
+  private int compare(int index1, int index2) {
+    Object baseObject = dataPage.getBaseObject();
+
+    long recordAddress1 = array.get(index1);
+    long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize;
+    int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize);
+
+    long recordAddress2 = array.get(index2);
+    long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize;
+    int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize);
+
+    return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject,
+        baseOffset2, recordLength2);
+  }
+
+  private final RecordComparator binaryComparator = new RecordBinaryComparator();
+
+  // Compute the most compact size for UnsafeRow's backing data.
+  private int computeSizeInBytes(int originalSize) {
+    // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall
+    // always be 8.
+    return 8 + (originalSize + 7) / 8 * 8;
+  }
+
+  // Compute the relative offset of variable-length values.
+  private long relativeOffset(int numFields) {
+    // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall
+    // always be 8.
+    return 8 + numFields * 8L;
+  }
+
+  @Test
+  public void testBinaryComparatorForSingleColumnRow() throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setInt(0, 11);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setInt(0, 42);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForMultipleColumnRow() throws Exception {
+    int numFields = 5;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields; i++) {
+      row1.setDouble(i, i * 3.14);
+    }
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields; i++) {
+      row2.setDouble(i, 198.7 / (i + 1));
+    }
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForArrayColumn() throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1});
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes()));
+    row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes());
+    Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1,
+        row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes());
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22});
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes()));
+    row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes());
+    Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2,
+        row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes());
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) > 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForMixedColumns() throws Exception {
+    int numFields = 4;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    UTF8String str1 = UTF8String.fromString("Milk tea");
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes()));
+    row1.setInt(0, 11);
+    row1.setDouble(1, 3.14);
+    row1.setInt(2, -1);
+    row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes());
+    Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1,
+        row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes());
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    UTF8String str2 = UTF8String.fromString("Java");
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes()));
+    row2.setInt(0, 11);
+    row2.setDouble(1, 3.14);
+    row2.setInt(2, -1);
+    row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes());
+    Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2,
+        row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes());
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) > 0);
+  }
+
+  @Test
+  public void testBinaryComparatorForNullColumns() throws Exception {
+    int numFields = 3;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields; i++) {
+      row1.setNullAt(i);
+    }
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    for (int i = 0; i < numFields - 1; i++) {
+      row2.setNullAt(i);
+    }
+    row2.setDouble(numFields - 1, 3.14);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 0) == 0);
+    assert(compare(0, 1) > 0);
+  }
+
+  @Test
+  public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setLong(0, 11);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setLong(0, 11L + Integer.MAX_VALUE);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception {
+    int numFields = 1;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setLong(0, Long.MIN_VALUE);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setLong(0, 1);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 1) < 0);
+  }
+
+  @Test
+  public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception {
+    int numFields = 4;
+
+    UnsafeRow row1 = new UnsafeRow(numFields);
+    byte[] data1 = new byte[100];
+    row1.pointTo(data1, computeSizeInBytes(numFields * 8));
+    row1.setInt(0, 11);
+    row1.setDouble(1, 3.14);
+    row1.setInt(2, -1);
+    row1.setLong(3, 0);
+
+    UnsafeRow row2 = new UnsafeRow(numFields);
+    byte[] data2 = new byte[100];
+    row2.pointTo(data2, computeSizeInBytes(numFields * 8));
+    row2.setInt(0, 11);
+    row2.setDouble(1, 3.14);
+    row2.setInt(2, -1);
+    row2.setLong(3, 1);
+
+    insertRow(row1);
+    insertRow(row2);
+
+    assert(compare(0, 1) < 0);
+  }
+}
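
The two subtraction-related tests above guard the comparator against deriving its result from the difference of two 8-byte words. The Java sketch below contrasts one such fragile pattern, `(int) ((v1 - v2) % Integer.MAX_VALUE)`, assumed here for illustration rather than quoted from Spark's source, with `Long.compare`, which never subtracts. `WordComparatorSketch` is hypothetical; it reads words in little-endian order, and for sorting before a repartition any deterministic total order is sufficient.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical sketch of word-at-a-time record comparison; not Spark's code.
final class WordComparatorSketch {

  // Fragile pattern (assumed for illustration): the subtraction can overflow,
  // and the modulo can turn a non-zero difference into 0 ("equal").
  static int naiveCompareWords(long v1, long v2) {
    return v1 == v2 ? 0 : (int) ((v1 - v2) % Integer.MAX_VALUE);
  }

  // Safe pattern: Long.compare orders the values without subtracting them.
  static int safeCompareWords(long v1, long v2) {
    return Long.compare(v1, v2);
  }

  // Lay out longs as little-endian 8-byte words, like fixed-width row slots.
  static byte[] encodeLongs(long... values) {
    ByteBuffer buf = ByteBuffer.allocate(values.length * 8).order(ByteOrder.LITTLE_ENDIAN);
    for (long v : values) {
      buf.putLong(v);
    }
    return buf.array();
  }

  // Compare 8 bytes at a time, then any trailing bytes, then by length.
  static int compareRecords(byte[] a, byte[] b) {
    ByteBuffer bufA = ByteBuffer.wrap(a).order(ByteOrder.LITTLE_ENDIAN);
    ByteBuffer bufB = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
    int common = Math.min(a.length, b.length);
    int i = 0;
    for (; i + 8 <= common; i += 8) {
      int c = safeCompareWords(bufA.getLong(i), bufB.getLong(i));
      if (c != 0) {
        return c;
      }
    }
    for (; i < common; i++) {
      int c = Byte.compare(a[i], b[i]);
      if (c != 0) {
        return c;
      }
    }
    return Integer.compare(a.length, b.length);
  }
}
```

With the suite's inputs, `naiveCompareWords` reports 11 and 11 + Integer.MAX_VALUE as equal and ranks Long.MIN_VALUE above 1, while `safeCompareWords` orders both pairs correctly.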

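The suite hand-builds UnsafeRows with `computeSizeInBytes`, `relativeOffset`, and `(offset << 32) | size` for variable-length fields. The arithmetic is sketched below in Java, assuming the layout the suite's comments describe: a null bitset of one 8-byte word per 64 fields, one 8-byte slot per field, then a word-aligned variable-length region. `UnsafeRowLayoutSketch` is a hypothetical helper, not a Spark class.

```java
// Hypothetical helper mirroring the arithmetic in RecordBinaryComparatorSuite.
final class UnsafeRowLayoutSketch {

  // Null bitset: one bit per field, rounded up to whole 8-byte words (assumed).
  static int bitSetWidthInBytes(int numFields) {
    return ((numFields + 63) / 64) * 8;
  }

  // Round a byte count up to the next multiple of 8 (word alignment).
  static int roundToWord(int numBytes) {
    return (numBytes + 7) / 8 * 8;
  }

  // Total row size: bitset plus word-aligned fixed slots and variable-length bytes.
  static int rowSizeInBytes(int numFields, int variableLengthBytes) {
    return bitSetWidthInBytes(numFields) + roundToWord(numFields * 8 + variableLengthBytes);
  }

  // Offset of the first variable-length byte, relative to the row start.
  static long variableRegionOffset(int numFields) {
    return bitSetWidthInBytes(numFields) + numFields * 8L;
  }

  // Pack offset (high 32 bits) and size (low 32 bits) into one field slot.
  static long packOffsetAndSize(long offset, int size) {
    return (offset << 32) | (long) size;
  }

  // Recover the offset from a packed slot.
  static long offsetOf(long packed) {
    return packed >>> 32;
  }

  // Recover the size from a packed slot.
  static int sizeOf(long packed) {
    return (int) packed;
  }
}
```

For the mixed-columns test, four fields plus the 8-byte string "Milk tea" give a 48-byte row whose string starts 40 bytes in, which is why `row1.setLong(3, ...)` packs offset 40 with size 8.
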
http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 59eaf4d..abd3e6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import scala.util.Random
+
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
     assert(exchange4.sameResult(exchange5))
     assert(exchange5 sameResult exchange4)
   }
+
+  test("SPARK-23207: Make repartition() generate consistent output") {
+    def assertConsistency(ds: Dataset[java.lang.Long]): Unit = {
+      ds.persist()
+
+      val exchange = ds.mapPartitions { iter =>
+        Random.shuffle(iter)
+      }.repartition(111)
+      val exchange2 = ds.repartition(111)
+
+      assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions())
+    }
+
+    withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") {
+      // repartition() should generate consistent output.
+      assertConsistency(spark.range(10000))
+
+      // case when input contains duplicated rows.
+      assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 94a2f9a..34f00aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -661,7 +661,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
             val v = (row.getInt(0), row.getString(1))
             result += v
           }
-          assert(data == result)
+          assert(data.toSet == result.toSet)
         } finally {
           reader.close()
         }
@@ -677,7 +677,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
             val row = reader.getCurrentValue.asInstanceOf[InternalRow]
             result += row.getString(0)
           }
-          assert(data.map(_._2) == result)
+          assert(data.map(_._2).toSet == result.toSet)
         } finally {
           reader.close()
         }
@@ -694,7 +694,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
             val v = (row.getString(0), row.getInt(1))
             result += v
           }
-          assert(data.map { x => (x._2, x._1) } == result)
+          assert(data.map { x => (x._2, x._1) }.toSet == result.toSet)
         } finally {
           reader.close()
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c3aae2/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 9137d65..41434e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
 
       var expectedEventsForPartition0 = Seq(
         ForeachSinkSuite.Open(partition = 0, version = 0),
-        ForeachSinkSuite.Process(value = 1),
+        ForeachSinkSuite.Process(value = 2),
         ForeachSinkSuite.Process(value = 3),
         ForeachSinkSuite.Close(None)
       )
       var expectedEventsForPartition1 = Seq(
         ForeachSinkSuite.Open(partition = 1, version = 0),
-        ForeachSinkSuite.Process(value = 2),
+        ForeachSinkSuite.Process(value = 1),
         ForeachSinkSuite.Process(value = 4),
         ForeachSinkSuite.Close(None)
       )

