Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19196#discussion_r138762394

--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
@@ -381,4 +388,233 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(streamInput, 0, 1, 2, 3),
       CheckLastBatch((0, 0, 2), (1, 1, 3)))
   }
+
+  private def checkAggregationChain(
+      sq: StreamingQuery,
+      requiresShuffling: Boolean,
+      expectedPartition: Int): Unit = {
+    val executedPlan = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+    val restore = executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore.child match {
+      case node: UnaryExecNode =>
+        assert(node.outputPartitioning.numPartitions === expectedPartition)
+        if (requiresShuffling) {
+          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}")
+        } else {
+          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
+        }
+
+      case _ => fail("Expected no shuffling")
+    }
+    var reachedRestore = false
+    // Check that there should be no exchanges after `StateStoreRestoreExec`
+    executedPlan.foreachUp { p =>
+      if (reachedRestore) {
+        assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
+      } else {
+        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
+      }
+    }
+  }
+
+  test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned accordingly") {
+    val inputSource = new NonLocalRelationSource(spark)
+    MockSourceProvider.withMockSources(inputSource) {
+      withTempDir { tempDir =>
+        val aggregated: Dataset[Long] =
+          spark.readStream
+            .format((new MockSourceProvider).getClass.getCanonicalName)
+            .load()
+            .coalesce(1)
+            .groupBy()
+            .count()
+            .as[Long]
+
+        val sq = aggregated.writeStream
+          .format("memory")
+          .outputMode("complete")
+          .queryName("agg_test")
+          .option("checkpointLocation", tempDir.getAbsolutePath)
+          .start()
+
+        try {
+
+          inputSource.addData(1)
+          inputSource.releaseLock()
+          sq.processAllAvailable()
+
+          checkDataset(
+            spark.table("agg_test").as[Long],
+            1L)
+
+          checkAggregationChain(sq, requiresShuffling = false, 1)
+
+          inputSource.addData()
+          inputSource.releaseLock()
+          sq.processAllAvailable()
+
+          checkAggregationChain(sq, requiresShuffling = true, 1)
+
+          checkDataset(
+            spark.table("agg_test").as[Long],
+            1L)
+
+          inputSource.addData(2, 3)
+          inputSource.releaseLock()
+          sq.processAllAvailable()
+
+          checkDataset(
+            spark.table("agg_test").as[Long],
+            3L)
+
+          inputSource.addData()
+          inputSource.releaseLock()
+          sq.processAllAvailable()
+
+          checkDataset(
+            spark.table("agg_test").as[Long],
+            3L)
+        } finally {
+          sq.stop()
+        }
+      }
+    }
+  }
+
+  test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") {
+    val inputSource = new NonLocalRelationSource(spark)
+    MockSourceProvider.withMockSources(inputSource) {
+      withTempDir { tempDir =>
+
+        val sq = spark.readStream
+          .format((new MockSourceProvider).getClass.getCanonicalName)
+          .load()
+          .coalesce(1)
+          .groupBy('a % 1) // just to give it a fake key
+          .count()
+          .as[(Long, Long)]
+          .writeStream
+          .format("memory")
+          .outputMode("complete")
+          .queryName("agg_test")
+          .option("checkpointLocation", tempDir.getAbsolutePath)
+          .start()
+
+        try {
+
+          inputSource.addData(1)
+          inputSource.releaseLock()
+          sq.processAllAvailable()
+
+          checkAggregationChain(
+            sq,
+            requiresShuffling = true,
+            spark.sessionState.conf.numShufflePartitions)
+
+          checkDataset(
+            spark.table("agg_test").as[(Long, Long)],
+            (0L, 1L))
+
+        } finally {
+          sq.stop()
+        }
+
+        val sq2 = spark.readStream
+          .format((new MockSourceProvider).getClass.getCanonicalName)
+          .load()
+          .coalesce(2)
+          .groupBy('a % 1) // just to give it a fake key
+          .count()
+          .as[(Long, Long)]
+          .writeStream
+          .format("memory")
+          .outputMode("complete")
+          .queryName("agg_test")
+          .option("checkpointLocation", tempDir.getAbsolutePath)
+          .start()
+
+        try {
+          sq2.processAllAvailable()
+          inputSource.addData(2)
+          inputSource.addData(3)
+          inputSource.addData(4)
+          inputSource.releaseLock()
+          sq2.processAllAvailable()
+
+          checkAggregationChain(
+            sq2,
+            requiresShuffling = false, // doesn't require extra shuffle as HashAggregate adds it
+            spark.sessionState.conf.numShufflePartitions)
+
+          checkDataset(
+            spark.table("agg_test").as[(Long, Long)],
+            (0L, 4L))
+
+          inputSource.addData()
+          inputSource.releaseLock()
+          sq2.processAllAvailable()
+
+          checkDataset(
+            spark.table("agg_test").as[(Long, Long)],
+            (0L, 4L))
+        } finally {
+          sq2.stop()
+        }
+      }
+    }
+  }
+}
+
+/**
+ * LocalRelation has some optimized properties during Spark planning. In order for the bugs in
+ * SPARK-21977 to occur, we need to create a logical relation from an existing RDD. We use a
+ * BlockRDD since it accepts 0 partitions. One requirement for the one of the bugs is the use of
+ * `coalesce(1)`, which has several optimizations regarding [[SinglePartition]], and a 0 partition
+ * parentRDD.
+ */
+class NonLocalRelationSource(spark: SparkSession) extends Source {
--- End diff --

The docs should explain what this class does, not why it is implemented the way it is. It really does not matter that LocalRelation is not the right thing to use.
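For example, something along these lines (hypothetical wording, only to illustrate the what-it-does framing; adjust to the actual behavior of the source):

    /**
     * A streaming Source whose batches are backed by a BlockRDD built from the data passed to
     * `addData`, so a batch may have zero partitions (e.g. when `addData()` is called with no
     * elements). Used with `coalesce(1)` by the SPARK-21977 regression tests above.
     */
    class NonLocalRelationSource(spark: SparkSession) extends Source {
      // class body unchanged
    }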