Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19196#discussion_r139078823
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
    @@ -381,4 +388,187 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
           AddData(streamInput, 0, 1, 2, 3),
           CheckLastBatch((0, 0, 2), (1, 1, 3)))
       }
    +
    +  /**
    +   * Verifies certain properties of the physical plan of a streaming aggregation.
    +   *
    +   * First, it checks that the child of the `StateStoreRestoreExec` creates the desired data
    +   * distribution. That child is either an `Exchange`, or a `HashAggregateExec` which already
    +   * provides the expected data distribution.
    +   *
    +   * Second, it checks that this child provides the expected number of partitions.
    +   *
    +   * Third, it checks that no unnecessary shuffle is added in-between `StateStoreRestoreExec`
    +   * and `StateStoreSaveExec`: no `Exchange` may be visited after the restore node in a
    +   * bottom-up (`foreachUp`) traversal of the executed plan.
    +   *
    +   * @param se the running query; its last executed plan is inspected
    +   * @param expectShuffling whether the child of `StateStoreRestoreExec` must be an `Exchange`
    +   * @param expectedPartition expected number of output partitions of that child
    +   * @return always `true`, so the call can be used directly inside `AssertOnQuery`;
    +   *         violations are reported through ScalaTest `assert`/`fail` instead
    +   */
    +  private def checkAggregationChain(
    +      se: StreamExecution,
    +      expectShuffling: Boolean,
    +      expectedPartition: Int): Boolean = {
    +    val executedPlan = se.lastExecution.executedPlan
    +    // Fail with the plan in the message instead of a bare NoSuchElementException from `.head`.
    +    val restore = executedPlan
    +      .collect { case ss: StateStoreRestoreExec => ss }
    +      .headOption
    +      .getOrElse(fail(s"Expected a StateStoreRestoreExec in the plan:\n$executedPlan"))
    +    restore.child match {
    +      case node: UnaryExecNode =>
    +        assert(node.outputPartitioning.numPartitions === expectedPartition,
    +          "Didn't get the expected number of partitions.")
    +        if (expectShuffling) {
    +          // Report the node that failed the check itself, not its child.
    +          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: $node")
    +        } else {
    +          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
    +        }
    +
    +      case other =>
    +        fail(s"Expected a unary node as the child of StateStoreRestoreExec, got: $other")
    +    }
    +    var reachedRestore = false
    +    // `foreachUp` visits children before their parent, so the nodes seen after the restore
    +    // node include all of its ancestors; none of them may be an `Exchange`.
    +    executedPlan.foreachUp { p =>
    +      if (reachedRestore) {
    +        assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
    +      } else {
    +        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
    +      }
    +    }
    +    true
    +  }
    +
    +  /**
    +   * Stream action that feeds a `BlockRDDBackedSource`: each `Seq` in `data` is handed to one
    +   * `source.addData` call. When no `Seq` is given, a single argument-less `addData()` call is
    +   * made instead, so that an empty blockRDD can be created. The source's lock is released
    +   * afterwards, and the source's current `counter` is reported as the new offset.
    +   */
    +  case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
    +    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
    +      if (data.isEmpty) {
    +        // Nothing to add: still add an (empty) block so an empty blockRDD can be created.
    +        source.addData()
    +      } else {
    +        for (block <- data) source.addData(block: _*)
    +      }
    +      source.releaseLock()
    +      val latestOffset = LongOffset(source.counter)
    +      (source, latestOffset)
    +    }
    +  }
    +
    +  test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") {
    +    val inputSource = new BlockRDDBackedSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      // This query runs without an explicit checkpoint location, so no temp dir is needed.
    +      val aggregated: Dataset[Long] =
    +        spark.readStream
    +          .format(classOf[MockSourceProvider].getCanonicalName)
    +          .load()
    +          .coalesce(1)
    +          .groupBy()
    +          .count()
    +          .as[Long]
    +
    +      testStream(aggregated, Complete())(
    +        AddBlockData(inputSource, Seq(1)),
    +        CheckLastBatch(1),
    +        // With data present, coalesce(1) already yields the single partition the global
    +        // aggregation needs, so no exchange is expected.
    +        AssertOnQuery("Verify no shuffling") { se =>
    +          checkAggregationChain(se, expectShuffling = false, expectedPartition = 1)
    +        },
    +        AddBlockData(inputSource), // create an empty trigger
    +        CheckLastBatch(1),
    +        // The empty trigger is the 0-partition case from the test name: an exchange must be
    +        // added so the aggregation still sees exactly 1 partition.
    +        AssertOnQuery("Verify addition of exchange operator") { se =>
    +          checkAggregationChain(se, expectShuffling = true, expectedPartition = 1)
    +        },
    +        AddBlockData(inputSource, Seq(2, 3)),
    +        CheckLastBatch(3),
    +        AddBlockData(inputSource),
    +        CheckLastBatch(3),
    +        StopStream
    +      )
    +    }
    +  }
    +
    +  test("SPARK-21977: coalesce(1) should still be repartitioned when it has keyExpressions") {
    +    val inputSource = new BlockRDDBackedSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +        // With a non-empty grouping key, the child of `StateStoreRestoreExec` is expected to
    +        // have `numShufflePartitions` partitions regardless of the coalesce factor. Both
    +        // queries below share one checkpoint directory: the second restarts the first with
    +        // coalesce(2) and must see the same number of partitions.
    +        def createDf(partitions: Int): Dataset[(Long, Long)] = {
    +          spark.readStream
    +            .format(classOf[MockSourceProvider].getCanonicalName)
    +            .load()
    +            .coalesce(partitions)
    +            .groupBy('a % 1) // just to give it a fake key
    +            .count()
    +            .as[(Long, Long)]
    +        }
    +
    +        // coalesce(1) with a grouping key: an exchange back to `numShufflePartitions`
    +        // partitions is expected directly below `StateStoreRestoreExec`.
    +        testStream(createDf(1), Complete())(
    +          StartStream(checkpointLocation = tempDir.getAbsolutePath),
    +          AddBlockData(inputSource, Seq(1)),
    +          CheckLastBatch((0L, 1L)),
    +          AssertOnQuery("Verify addition of exchange operator") { se =>
    +            checkAggregationChain(
    +              se,
    +              expectShuffling = true,
    +              expectedPartition = spark.sessionState.conf.numShufflePartitions)
    +          },
    +          StopStream
    +        )
    +
    +        // Restart from the same checkpoint with coalesce(2): the child of
    +        // `StateStoreRestoreExec` must not be an `Exchange`, yet it must still provide
    +        // `numShufflePartitions` partitions.
    +        testStream(createDf(2), Complete())(
    +          StartStream(checkpointLocation = tempDir.getAbsolutePath),
    +          // NOTE(review): presumably lets the restarted query finish any batch recovered from
    +          // the checkpoint before new data is added; confirm.
    +          Execute(se => se.processAllAvailable()),
    +          AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)),
    +          CheckLastBatch((0L, 4L)),
    +          AssertOnQuery("Verify no exchange added") { se =>
    +            checkAggregationChain(
    +              se,
    +              expectShuffling = false,
    +              expectedPartition = spark.sessionState.conf.numShufflePartitions)
    +          },
    +          AddBlockData(inputSource),
    +          CheckLastBatch((0L, 4L)),
    +          StopStream
    +        )
    +      }
    +    }
    +  }
    +}
    +
    +/**
    + * A Streaming Source that is backed by a BlockRDD and that can create 
RDDs with 0 blocks at will.
    + */
    +class BlockRDDBackedSource(spark: SparkSession) extends Source {
    +  var counter = 0L
    +  private val blockMgr = SparkEnv.get.blockManager
    +  private var blocks: Seq[BlockId] = Seq.empty
    +
    +  private var streamLock: CountDownLatch = new CountDownLatch(1)
    +
    +  def addData(data: Int*): Unit = {
    +    if (streamLock.getCount == 0) {
    +      streamLock = new CountDownLatch(1)
    --- End diff --
    
    This latch-based locking is complicated. See how `AddFileData` is implemented; it's much simpler.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to