Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19196#discussion_r139078823

--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
@@ -381,4 +388,187 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(streamInput, 0, 1, 2, 3),
       CheckLastBatch((0, 0, 2), (1, 1, 3)))
   }
+
+  /**
+   * This method verifies certain properties in the SparkPlan of a streaming aggregation.
+   * First, it checks that the child of a `StateStoreRestoreExec` creates the desired
+   * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which
+   * already provides the expected data distribution.
+   *
+   * Second, it checks that the child provides the expected number of partitions.
+   *
+   * Third, it checks that no unnecessary shuffle is added between `StateStoreRestoreExec`
+   * and `StateStoreSaveExec`.
+   */
+  private def checkAggregationChain(
+      se: StreamExecution,
+      expectShuffling: Boolean,
+      expectedPartition: Int): Boolean = {
+    val executedPlan = se.lastExecution.executedPlan
+    val restore = executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore.child match {
+      case node: UnaryExecNode =>
+        assert(node.outputPartitioning.numPartitions === expectedPartition,
+          "Didn't get the expected number of partitions.")
+        if (expectShuffling) {
+          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: $node")
+        } else {
+          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
+        }
+
+      case other => fail(s"Expected a UnaryExecNode, got: $other")
+    }
+    var reachedRestore = false
+    // Check that there are no exchanges after `StateStoreRestoreExec`
+    executedPlan.foreachUp { p =>
+      if (reachedRestore) {
+        assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
+      } else {
+        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
+      }
+    }
+    true
+  }
+
+  /** Add blocks of data to the `BlockRDDBackedSource`. */
+  case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
+    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+      if (data.nonEmpty) {
+        data.foreach(source.addData)
+      } else {
+        // We would like to create empty BlockRDDs, so add an empty block here.
+        source.addData()
+      }
+      source.releaseLock()
+      (source, LongOffset(source.counter))
+    }
+  }
+
+  test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") {
+    val inputSource = new BlockRDDBackedSource(spark)
+    MockSourceProvider.withMockSources(inputSource) {
+      withTempDir { tempDir =>
+        val aggregated: Dataset[Long] =
+          spark.readStream
+            .format((new MockSourceProvider).getClass.getCanonicalName)
+            .load()
+            .coalesce(1)
+            .groupBy()
+            .count()
+            .as[Long]
+
+        testStream(aggregated, Complete())(
+          AddBlockData(inputSource, Seq(1)),
+          CheckLastBatch(1),
+          AssertOnQuery("Verify no shuffling") { se =>
+            checkAggregationChain(se, expectShuffling = false, 1)
+          },
+          AddBlockData(inputSource), // create an empty trigger
+          CheckLastBatch(1),
+          AssertOnQuery("Verify addition of exchange operator") { se =>
+            checkAggregationChain(se, expectShuffling = true, 1)
+          },
+          AddBlockData(inputSource, Seq(2, 3)),
+          CheckLastBatch(3),
+          AddBlockData(inputSource),
+          CheckLastBatch(3),
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") {
+    val inputSource = new BlockRDDBackedSource(spark)
+    MockSourceProvider.withMockSources(inputSource) {
+      withTempDir { tempDir =>
+
+        def createDf(partitions: Int): Dataset[(Long, Long)] = {
+          spark.readStream
+            .format((new MockSourceProvider).getClass.getCanonicalName)
+            .load()
+            .coalesce(partitions)
+            .groupBy('a % 1) // just to give it a fake key
+            .count()
+            .as[(Long, Long)]
+        }
+
+        testStream(createDf(1), Complete())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddBlockData(inputSource, Seq(1)),
+          CheckLastBatch((0L, 1L)),
+          AssertOnQuery("Verify addition of exchange operator") { se =>
+            checkAggregationChain(
+              se,
+              expectShuffling = true,
+              spark.sessionState.conf.numShufflePartitions)
+          },
+          StopStream
+        )
+
+        testStream(createDf(2), Complete())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          Execute(se => se.processAllAvailable()),
+          AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)),
+          CheckLastBatch((0L, 4L)),
+          AssertOnQuery("Verify no exchange added") { se =>
+            checkAggregationChain(
+              se,
+              expectShuffling = false,
+              spark.sessionState.conf.numShufflePartitions)
+          },
+          AddBlockData(inputSource),
+          CheckLastBatch((0L, 4L)),
+          StopStream
+        )
+      }
+    }
+  }
+}
+
+/**
+ * A streaming Source backed by a BlockRDD that can create RDDs with 0 blocks at will.
+ */
+class BlockRDDBackedSource(spark: SparkSession) extends Source {
+  var counter = 0L
+  private val blockMgr = SparkEnv.get.blockManager
+  private var blocks: Seq[BlockId] = Seq.empty
+
+  private var streamLock: CountDownLatch = new CountDownLatch(1)
+
+  def addData(data: Int*): Unit = {
+    if (streamLock.getCount == 0) {
+      streamLock = new CountDownLatch(1)
--- End diff --

This is complicated. See how AddFileData is implemented. It's much simpler.
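For comparison, a latch-free version of this action, loosely modeled on the AddFileData pattern, might look like the sketch below. This is only an illustration of the simpler shape, not code from this PR: it assumes `BlockRDDBackedSource.addData` and `counter` keep the signatures shown in the diff, drops the `CountDownLatch` entirely, and relies on StreamTest's own batch-completion checks (e.g. `CheckLastBatch`) for synchronization.

    // Hypothetical simplification: the action only makes new input visible and
    // reports the new offset; no CountDownLatch hand-off with getBatch is needed.
    case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
      override def addData(query: Option[StreamExecution]): (Source, Offset) = {
        source.synchronized {
          if (data.nonEmpty) {
            data.foreach(source.addData)
          } else {
            // Add an empty block so the next batch sees a 0-partition BlockRDD.
            source.addData()
          }
          (source, LongOffset(source.counter))
        }
      }
    }

The key simplification is that AddFileData never blocks the stream itself: it only publishes new input and returns the offset, letting the test harness drive the trigger.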