viirya commented on code in PR #44636:
URL: https://github.com/apache/spark/pull/44636#discussion_r1447955436
########## sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala: ##########
@@ -1153,174 +1153,213 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
     }
   }
 
-  test("max files per trigger") {
-    withTempDir { case src =>
-      var lastFileModTime: Option[Long] = None
-
-      /** Create a text file with a single data item */
-      def createFile(data: Int): File = {
-        val file = stringToFile(new File(src, s"$data.txt"), data.toString)
-        if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000)
-        lastFileModTime = Some(file.lastModified)
-        file
-      }
-
-      createFile(1)
-      createFile(2)
-      createFile(3)
-
-      // Set up a query to read text files 2 at a time
-      val df = spark
-        .readStream
-        .option("maxFilesPerTrigger", 2)
-        .text(src.getCanonicalPath)
-      val q = df
-        .writeStream
-        .format("memory")
-        .queryName("file_data")
-        .start()
-        .asInstanceOf[StreamingQueryWrapper]
-        .streamingQuery
-      q.processAllAvailable()
-      val memorySink = q.sink.asInstanceOf[MemorySink]
-      val fileSource = getSourcesFromStreamingQuery(q).head
-
-      /** Check the data read in the last batch */
-      def checkLastBatchData(data: Int*): Unit = {
-        val schema = StructType(Seq(StructField("value", StringType)))
-        val df = spark.createDataFrame(
-          spark.sparkContext.makeRDD(memorySink.latestBatchData), schema)
-        checkAnswer(df, data.map(_.toString).toDF("value"))
-      }
-
-      def checkAllData(data: Seq[Int]): Unit = {
-        val schema = StructType(Seq(StructField("value", StringType)))
-        val df = spark.createDataFrame(
-          spark.sparkContext.makeRDD(memorySink.allData), schema)
-        checkAnswer(df, data.map(_.toString).toDF("value"))
-      }
-
-      /** Check how many batches have executed since the last time this check was made */
-      var lastBatchId = -1L
-      def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = {
-        require(lastBatchId >= 0)
-        assert(memorySink.latestBatchId.get === lastBatchId + numBatches)
-        lastBatchId = memorySink.latestBatchId.get
-      }
-
-      checkLastBatchData(3) // (1 and 2) should be in batch 1, (3) should be in batch 2 (last)
-      checkAllData(1 to 3)
-      lastBatchId = memorySink.latestBatchId.get
-
-      fileSource.withBatchingLocked {
-        createFile(4)
-        createFile(5) // 4 and 5 should be in a batch
-        createFile(6)
-        createFile(7) // 6 and 7 should be in the last batch
-      }
-      q.processAllAvailable()
-      checkNumBatchesSinceLastCheck(2)
-      checkLastBatchData(6, 7)
-      checkAllData(1 to 7)
-
-      fileSource.withBatchingLocked {
-        createFile(8)
-        createFile(9) // 8 and 9 should be in a batch
-        createFile(10)
-        createFile(11) // 10 and 11 should be in a batch
-        createFile(12) // 12 should be in the last batch
-      }
-      q.processAllAvailable()
-      checkNumBatchesSinceLastCheck(3)
-      checkLastBatchData(12)
-      checkAllData(1 to 12)
-
-      q.stop()
-    }
-  }
+  test("maxFilesPerTrigger & maxBytesPerTrigger threshold logic must be obeyed") {
+    Seq(
+      "maxFilesPerTrigger",
+      "maxBytesPerTrigger"
+    ).foreach{ thresholdOption =>
+      withTempDir { case src =>
+        var lastFileModTime: Option[Long] = None
+
+        /** Create a text file with a single data item */
+        def createFile(data: String): File = {
+          val file = stringToFile(new File(src, s"$data.txt"), data)
+          if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000)
+          lastFileModTime = Some(file.lastModified)
+          file
+        }
+
+        createFile("a")
+        createFile("b")
+        createFile("c")
+
+        // Set up a query to read text files 2 at a time
+        val df = spark
+          .readStream
+          .option(thresholdOption, 2)
+          .text(src.getCanonicalPath)
+        val q = df
+          .writeStream
+          .format("memory")
+          .queryName("file_data")
+          .start()
+          .asInstanceOf[StreamingQueryWrapper]
+          .streamingQuery
+        q.processAllAvailable()
+        val memorySink = q.sink.asInstanceOf[MemorySink]
+        val fileSource = getSourcesFromStreamingQuery(q).head
+
+        /** Check the data read in the last batch */
+        def checkLastBatchData(data: Char*): Unit = {
+          val schema = StructType(Seq(StructField("value", StringType)))
+          val df = spark.createDataFrame(
+            spark.sparkContext.makeRDD(memorySink.latestBatchData), schema)
+          checkAnswer(df, data.map(_.toString).toDF("value"))
+        }
+
+        def checkAllData(data: Seq[Char]): Unit = {
+          val schema = StructType(Seq(StructField("value", StringType)))
+          val df = spark.createDataFrame(
+            spark.sparkContext.makeRDD(memorySink.allData), schema)
+          checkAnswer(df, data.map(_.toString).toDF("value"))
+        }
+
+        /** Check how many batches have executed since the last time this check was made */
+        var lastBatchId = -1L
+        def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = {
+          require(lastBatchId >= 0)
+          assert(memorySink.latestBatchId.get === lastBatchId + numBatches)
+          lastBatchId = memorySink.latestBatchId.get
+        }
+
+        checkLastBatchData('c') // (a and b) should be in batch 1, (c) should be in batch 2 (last)
+        checkAllData('a' to 'c')
+        lastBatchId = memorySink.latestBatchId.get
+
+        fileSource.withBatchingLocked {
+          createFile("d")
+          createFile("e") // d and e should be in a batch
+          createFile("f")
+          createFile("g") // f and g should be in the last batch
+        }
+        q.processAllAvailable()
+        checkNumBatchesSinceLastCheck(2)
+        checkLastBatchData('f', 'g')
+        checkAllData('a' to 'g')
+
+        fileSource.withBatchingLocked {
+          createFile("h")
+          createFile("i") // h and i should be in a batch
+          createFile("j")
+          createFile("k") // j and k should be in a batch
+          createFile("l") // l should be in the last batch
+        }
+        q.processAllAvailable()
+        checkNumBatchesSinceLastCheck(3)
+        checkLastBatchData('l')
+        checkAllData('a' to 'l')
+
+        q.stop()
+      }
+    }
+  }
 
-  testQuietly("max files per trigger - incorrect values") {
-    val testTable = "maxFilesPerTrigger_test"
-    withTable(testTable) {
-      withTempDir { case src =>
-        def testMaxFilePerTriggerValue(value: String): Unit = {
-          val df = spark.readStream.option("maxFilesPerTrigger", value).text(src.getCanonicalPath)
-          val e = intercept[StreamingQueryException] {
-            // Note: `maxFilesPerTrigger` is checked in the stream thread when creating the source
-            val q = df.writeStream.format("memory").queryName(testTable).start()
-            try {
-              q.processAllAvailable()
-            } finally {
-              q.stop()
+  testQuietly("max bytes per trigger & max files per trigger - incorrect values") {
+    Seq(
+      ("maxBytesPerTrigger_test", "maxBytesPerTrigger"),
+      ("maxFilesPerTrigger_test", "maxFilesPerTrigger")
+    ).foreach { case (testTable, optionName) =>
+      withTable(testTable) {
+        withTempDir { case src =>
+          def testMaxFilePerTriggerValue(value: String): Unit = {

Review Comment:
   Maybe `testMaxFileOrBytesPerTriggerValue`? Or `testMaxOptionPerTriggerValue`.
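For illustration, the suggested rename might read as in the sketch below. This is not code from the PR: the helper body is copied from the removed test quoted above, and the use of `optionName` in place of the hard-coded `"maxFilesPerTrigger"` is an assumption about the rest of the new hunk, which is cut off at the helper's signature.

```scala
// Sketch of the reviewer's naming suggestion, not code from the PR.
// Assumes it sits inside the Seq(...).foreach { case (testTable, optionName) => ... }
// block shown above, so `testTable` and `optionName` are in scope.
def testMaxOptionPerTriggerValue(value: String): Unit = {
  val df = spark.readStream.option(optionName, value).text(src.getCanonicalPath)
  val e = intercept[StreamingQueryException] {
    // Note: the option is checked in the stream thread when creating the source
    val q = df.writeStream.format("memory").queryName(testTable).start()
    try {
      q.processAllAvailable()
    } finally {
      q.stop()
    }
  }
  // whatever assertions the test makes on `e` would follow here, unchanged
}
```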