Repository: spark Updated Branches: refs/heads/master 9abe09bfc -> 4f1758509
[SPARK-19355][SQL] Use map output statistics to improve global limit's parallelism ## What changes were proposed in this pull request? A logical `Limit` is performed physically by two operations, `LocalLimit` and `GlobalLimit`. Most of the time, we gather all data into a single partition in order to run `GlobalLimit`. If we use a very big limit number, shuffling the data causes performance issues and also reduces parallelism. We can avoid shuffling into a single partition if we don't care about data ordering. This patch implements this idea by doing a map stage during global limit. It collects the number of rows in each partition. For each partition, we locally retrieve the limited data without any shuffling to finish this global limit. For example, suppose we have three partitions with (100, 100, 50) rows respectively. In a global limit of 100 rows, we may take (34, 33, 33) rows from each partition locally. After the global limit we still have three partitions. If the data partitions have a certain ordering, we can't distribute the required rows evenly across the partitions because that could change the data ordering. But we can still avoid shuffling. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #16677 from viirya/improve-global-limit-parallelism. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f175850 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f175850 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f175850 Branch: refs/heads/master Commit: 4f175850985cfc4c64afb90d784bb292e81dc0b7 Parents: 9abe09b Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Fri Aug 10 11:32:15 2018 +0200 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Fri Aug 10 11:32:15 2018 +0200 ---------------------------------------------------------------------- .../sort/BypassMergeSortShuffleWriter.java | 5 +- .../spark/shuffle/sort/UnsafeShuffleWriter.java | 3 +- .../org/apache/spark/MapOutputStatistics.scala | 6 +- .../org/apache/spark/MapOutputTracker.scala | 10 +- .../org/apache/spark/scheduler/MapStatus.scala | 43 +++++--- .../spark/shuffle/sort/SortShuffleWriter.scala | 3 +- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 2 + .../apache/spark/MapOutputTrackerSuite.scala | 28 ++--- .../scala/org/apache/spark/ShuffleSuite.scala | 1 + .../spark/scheduler/DAGSchedulerSuite.scala | 10 +- .../apache/spark/scheduler/MapStatusSuite.scala | 16 +-- .../spark/serializer/KryoSerializerSuite.scala | 3 +- .../catalyst/plans/physical/partitioning.scala | 14 +++ .../org/apache/spark/sql/internal/SQLConf.scala | 9 ++ .../exchange/ShuffleExchangeExec.scala | 8 ++ .../org/apache/spark/sql/execution/limit.scala | 101 ++++++++++++++++--- .../test/resources/sql-tests/inputs/limit.sql | 2 + .../inputs/subquery/in-subquery/in-limit.sql | 5 +- .../resources/sql-tests/results/limit.sql.out | 92 +++++++++-------- .../subquery/in-subquery/in-limit.sql.out | 56 +++++----- .../spark/sql/DataFrameAggregateSuite.scala | 12 ++- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +- .../execution/ExchangeCoordinatorSuite.scala | 6 +- .../spark/sql/execution/PlannerSuite.scala | 4 +- .../hive/execution/HiveCompatibilitySuite.scala | 4 + 
.../spark/sql/hive/execution/PruningSuite.scala | 8 ++ 26 files changed, 322 insertions(+), 140 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3..e3bd549 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +167,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04..069e6d5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,7 +248,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d..ff85e11 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,5 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) + * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val recordsByPartitionId: Array[Long]) http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4b..41575ce 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -522,16 +522,19 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) + val recordsByMapTask = new Array[Long](statuses.length) + val parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - for (s <- statuses) { + statuses.zipWithIndex.foreach { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } + recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -548,8 +551,11 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } + statuses.zipWithIndex.foreach { case (s, index) => + recordsByMapTask(index) = s.numberOfOutput + } } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) } } http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 659694d..7e1d75f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,7 +31,8 @@ import 
org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, + * for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -44,18 +45,23 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + /** + * The number of outputs for the map task. + */ + def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { if (uncompressedSizes.length > Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, numOutput) } } @@ -98,29 +104,34 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, 
uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -143,17 +154,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -168,6 +182,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { 
loc.writeExternal(out) + out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -179,6 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -194,7 +210,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + numOutput: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -235,6 +254,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizesArray.toMap, numOutput) } } http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b..91fc267 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) 
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, + writeMetrics.recordsWritten) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea..faa70f2 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,6 +233,7 @@ public class UnsafeShuffleWriterSuite { writer.write(Iterators.emptyIterator()); final Option<MapStatus> mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -252,6 +253,7 @@ public class UnsafeShuffleWriterSuite { writer.write(dataToWrite.iterator()); final Option<MapStatus> mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala 
b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d..e797396 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 10)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 10)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) 
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 1)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
@@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), 1)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/test/scala/org/apache/spark/ShuffleSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index ced5a06..d11eaf8 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -391,6 +391,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one 
mapOutput1.foreach { case mapStatus => http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 211002b..5e095ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -423,17 +423,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -2576,7 +2576,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), 
Array.fill[Long](reduces)(sizes)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 354e638..555e48b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -108,7 +108,7 
@@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) val sizes = Array.fill[Long](500)(150L) // Test default value - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[CompressedMapStatus]) // Test Non-positive values for (s <- -1 to 0) { assertThrows[IllegalArgumentException] { conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) } } // Test positive values Seq(1, 100, 499, 500, 501).foreach { s => conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) if(sizes.length > s) { assert(status.isInstanceOf[HighlyCompressedMapStatus]) } else { http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala ---------------------------------------------------------------------- diff --git 
a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index fc78655..240f8cf 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,7 +345,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize( + HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc1a5e8..cd28c73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -207,6 +209,18 @@ case object SinglePartition extends Partitioning { } /** + * Represents a partitioning where rows are only serialized/deserialized locally. The number + * of partitions are not changed and also the distribution of rows. 
This is mainly used to + * obtain some statistics of map tasks such as number of outputs. + */ +case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { + val numPartitions = childRDD.getNumPartitions + + // We will perform this partitioning no matter what the data distribution is. + override def satisfies0(required: Distribution): Boolean = false +} + +/** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 979a554..603c070 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -214,6 +214,13 @@ object SQLConf { .intConf .createWithDefault(4) + val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") + .internal() + .doc("During global limit, try to evenly distribute limited rows across data " + + "partitions. 
If disabled, scanning data partitions sequentially until reaching limit number.") + .booleanConf + .createWithDefault(true) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") .internal() @@ -1682,6 +1689,8 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) + def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index b892037..50f10c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,6 +231,11 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } + case l: LocalPartitioning => + new Partitioner { + override def numPartitions: Int = l.numPartitions + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. 
} @@ -247,6 +252,9 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity + case _: LocalPartitioning => + val partitionId = TaskContext.get().partitionId() + _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 66bcda8..392ca13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,13 +47,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Helper trait which defines methods that are shared by both - * [[LocalLimitExec]] and [[GlobalLimitExec]]. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -trait BaseLimitExec extends UnaryExecNode with CodegenSupport { - val limit: Int +case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the `limit` elements of the child output. 
*/ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning -} -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { + val childRDD = child.execute() + val partitioner = LocalPartitioning(childRDD) + val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( + childRDD, child.output, partitioner, serializer) + val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) + submittedStageFuture.get().recordsByPartitionId.toSeq + } else { + Nil + } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + // During global limit, try to evenly distribute limited rows across data + // partitions. If disabled, scan data partitions sequentially until the limit is reached. + // Besides, if child output has certain ordering, we can't evenly pick up rows from + // each partition.
+ val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil + + val shuffled = new ShuffledRowRDD(shuffleDependency) + + val sumOfOutput = numberOfOutput.sum + if (sumOfOutput <= limit) { + shuffled + } else if (!flatGlobalLimit) { + var numRowTaken = 0 + val takeAmounts = numberOfOutput.map { num => + if (numRowTaken + num < limit) { + numRowTaken += num.toInt + num.toInt + } else { + val toTake = limit - numRowTaken + numRowTaken += toTake + toTake + } + } + val broadMap = sparkContext.broadcast(takeAmounts) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } else { + // We try to take the requested `limit` rows evenly across all of the child RDD's partitions. + var rowsNeedToTake: Long = limit + val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) + val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) + + while (rowsNeedToTake > 0) { + val nonEmptyParts = remainingRowsByPartition.count(_ > 0) + // If fewer rows are needed than the number of non-empty partitions, take one row from + // each non-empty partition until we reach `limit` rows. + // Otherwise, evenly divide the needed rows among the non-empty partitions. + val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) + remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => + // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during + // the traversal, so we need to add this check.
+ if (rowsNeedToTake > 0 && num > 0) { + if (num >= takePerPart) { + rowsNeedToTake -= takePerPart + takeAmountByPartition(index) += takePerPart + remainingRowsByPartition(index) -= takePerPart + } else { + rowsNeedToTake -= num + takeAmountByPartition(index) += num + remainingRowsByPartition(index) -= num + } + } + } + } + val broadMap = sparkContext.broadcast(takeAmountByPartition) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/resources/sql-tests/inputs/limit.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index b4c73cf..e33cd81 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,3 +1,5 @@ +-- Disable global limit parallel +set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a40ee08..a862e09 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,6 +1,9 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. 
+-- Disable global limit optimization +set spark.sql.limit.flatGlobalLimit=false; + create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -97,4 +100,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; \ No newline at end of file +LIMIT 1; http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/resources/sql-tests/results/limit.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 02fe1de..187f3bd 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,63 +1,62 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 15 -- !query 0 -SELECT * FROM testdata LIMIT 2 +set spark.sql.limit.flatGlobalLimit=false -- !query 0 schema -struct<key:int,value:string> +struct<key:string,value:string> -- !query 0 output -1 1 -2 2 +spark.sql.limit.flatGlobalLimit false -- !query 1 -SELECT * FROM arraydata LIMIT 2 +SELECT * FROM testdata LIMIT 2 -- !query 1 schema -struct<arraycol:array<int>,nestedarraycol:array<array<int>>> +struct<key:int,value:string> -- !query 1 output -[1,2,3] [[1,2,3]] -[2,3,4] [[2,3,4]] +1 1 +2 2 -- !query 2 -SELECT * FROM mapdata LIMIT 2 +SELECT * FROM arraydata LIMIT 2 -- !query 2 schema -struct<mapcol:map<int,string>> +struct<arraycol:array<int>,nestedarraycol:array<array<int>>> -- !query 2 output -{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} -{1:"a2",2:"b2",3:"c2",4:"d2"} +[1,2,3] [[1,2,3]] +[2,3,4] [[2,3,4]] -- !query 3 -SELECT * FROM testdata LIMIT 2 + 1 +SELECT * FROM mapdata LIMIT 2 -- !query 3 schema 
-struct<key:int,value:string> +struct<mapcol:map<int,string>> -- !query 3 output -1 1 -2 2 -3 3 +{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} +{1:"a2",2:"b2",3:"c2",4:"d2"} -- !query 4 -SELECT * FROM testdata LIMIT CAST(1 AS int) +SELECT * FROM testdata LIMIT 2 + 1 -- !query 4 schema struct<key:int,value:string> -- !query 4 output 1 1 +2 2 +3 3 -- !query 5 -SELECT * FROM testdata LIMIT -1 +SELECT * FROM testdata LIMIT CAST(1 AS int) -- !query 5 schema -struct<> +struct<key:int,value:string> -- !query 5 output -org.apache.spark.sql.AnalysisException -The limit expression must be equal to or greater than 0, but got -1; +1 1 -- !query 6 -SELECT * FROM testData TABLESAMPLE (-1 ROWS) +SELECT * FROM testdata LIMIT -1 -- !query 6 schema struct<> -- !query 6 output @@ -66,61 +65,70 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -SELECT * FROM testdata LIMIT CAST(1 AS INT) +SELECT * FROM testData TABLESAMPLE (-1 ROWS) -- !query 7 schema -struct<key:int,value:string> +struct<> -- !query 7 output -1 1 +org.apache.spark.sql.AnalysisException +The limit expression must be equal to or greater than 0, but got -1; -- !query 8 -SELECT * FROM testdata LIMIT CAST(NULL AS INT) +SELECT * FROM testdata LIMIT CAST(1 AS INT) -- !query 8 schema -struct<> +struct<key:int,value:string> -- !query 8 output -org.apache.spark.sql.AnalysisException -The evaluated limit expression must not be null, but got CAST(NULL AS INT); +1 1 -- !query 9 -SELECT * FROM testdata LIMIT key > 3 +SELECT * FROM testdata LIMIT CAST(NULL AS INT) -- !query 9 schema struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +The evaluated limit expression must not be null, but got CAST(NULL AS INT); -- !query 10 -SELECT * FROM testdata LIMIT true +SELECT * FROM testdata LIMIT key > 3 -- !query 10 schema struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -The limit expression must 
be integer type, but got boolean; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 11 -SELECT * FROM testdata LIMIT 'a' +SELECT * FROM testdata LIMIT true -- !query 11 schema struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must be integer type, but got boolean; -- !query 12 -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +SELECT * FROM testdata LIMIT 'a' -- !query 12 schema -struct<id:bigint> +struct<> -- !query 12 output -4 +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; -- !query 13 -SELECT * FROM testdata WHERE key < 3 LIMIT ALL +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 -- !query 13 schema -struct<key:int,value:string> +struct<id:bigint> -- !query 13 output +4 + + +-- !query 14 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 14 schema +struct<key:int,value:string> +-- !query 14 output 1 1 2 2 http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 71ca1f8..9eb5b33 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,8 +1,16 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 +set spark.sql.limit.flatGlobalLimit=false +-- !query 0 schema +struct<key:string,value:string> +-- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 create temporary view t1 as select * 
from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -17,13 +25,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 0 schema +-- !query 1 schema struct<> --- !query 0 output +-- !query 1 output --- !query 1 +-- !query 2 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -39,13 +47,13 @@ create temporary view t2 as select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 1 schema +-- !query 2 schema struct<> --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -60,27 +68,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 2 schema +-- !query 3 schema struct<> --- !query 2 output 
+-- !query 3 output --- !query 3 +-- !query 4 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 3 schema +-- !query 4 schema struct<t1a:string,t1b:smallint,t1c:int,t1d:bigint,t1e:float,t1f:double,t1g:decimal(2,-2),t1h:timestamp,t1i:date> --- !query 3 output +-- !query 4 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 4 +-- !query 5 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -88,16 +96,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 4 schema +-- !query 5 schema struct<t1a:string,t1b:smallint,t1c:int,t1d:bigint,t1e:float,t1f:double,t1g:decimal(2,-2),t1h:timestamp,t1i:date> --- !query 4 output +-- !query 5 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 6 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -108,29 +116,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 5 schema +-- !query 6 schema struct<count(DISTINCT t1a):bigint,t1b:smallint> --- !query 5 output +-- !query 6 output 1 NULL --- !query 6 +-- !query 7 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 6 schema +-- !query 7 schema struct<t1a:string,t1b:smallint,t1c:int,t1d:bigint,t1e:float,t1f:double,t1g:decimal(2,-2),t1h:timestamp,t1i:date> --- !query 6 output +-- !query 7 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 7 +-- !query 8 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -141,7 +149,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY 
t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 7 schema +-- !query 8 schema struct<count(DISTINCT t1a):bigint,t1b:smallint> --- !query 7 output +-- !query 8 output 1 6 http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c4..85b3ca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,11 +557,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } test("SPARK-17237 remove backticks in a pivot result schema") { http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3a393d7..c1a5f50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -524,6 +524,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } + 
test("limit for skew dataframe") { + // Create a skewed dataframe. + val df = testData.repartition(100).union(testData).limit(50) + // Because the `rdd` of a dataframe adds a `DeserializeToObject` on top of `GlobalLimit`, + // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit` + // works on skewed partitions. + assert(df.rdd.count() == 50L) + } + test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1935,7 +1944,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") + val df = sql("SELECT a, b from testData2 order by a, b limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b736d43..41de731 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) + new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) } val estimatedPartitionStartIndices =
coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bdc1063..3db89ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } { @@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } } http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cebaad5..b9b2b7d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -59,6 +60,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Ensure that limit operation returns rows in the same order as Hive + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -73,6 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was 
spent in various optimizer rules http://git-wip-us.apache.org/repos/asf/spark/blob/4f175850/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index cc592cf..1654129 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,21 +22,29 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} +import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. TestHive.reset() } + override def afterAll() { + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) + super.afterAll() + } // Column pruning tests --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org