dongjoon-hyun commented on code in PR #44636:
URL: https://github.com/apache/spark/pull/44636#discussion_r1448101899


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala:
##########
@@ -1153,174 +1153,213 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
     }
   }
 
-  test("max files per trigger") {
-    withTempDir { case src =>
-      var lastFileModTime: Option[Long] = None
+  test("maxFilesPerTrigger & maxBytesPerTrigger threshold logic must be obeyed") {
+    Seq(
+      "maxFilesPerTrigger",
+      "maxBytesPerTrigger"
+    ).foreach{ thresholdOption =>
+      withTempDir { case src =>
+        var lastFileModTime: Option[Long] = None
+
+        /** Create a text file with a single data item */
+        def createFile(data: String): File = {
+          val file = stringToFile(new File(src, s"$data.txt"), data)
+          if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000)
+          lastFileModTime = Some(file.lastModified)
+          file
+        }
 
-      /** Create a text file with a single data item */
-      def createFile(data: Int): File = {
-        val file = stringToFile(new File(src, s"$data.txt"), data.toString)
-        if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000)
-        lastFileModTime = Some(file.lastModified)
-        file
-      }
+        createFile("a")
+        createFile("b")
+        createFile("c")
 
-      createFile(1)
-      createFile(2)
-      createFile(3)
+        // Set up a query to read text files 2 at a time
+        val df = spark
+          .readStream
+          .option(thresholdOption, 2)
+          .text(src.getCanonicalPath)
+        val q = df
+          .writeStream
+          .format("memory")
+          .queryName("file_data")
+          .start()
+          .asInstanceOf[StreamingQueryWrapper]
+          .streamingQuery
+        q.processAllAvailable()
+        val memorySink = q.sink.asInstanceOf[MemorySink]
+        val fileSource = getSourcesFromStreamingQuery(q).head
+
+        /** Check the data read in the last batch */
+        def checkLastBatchData(data: Char*): Unit = {
+          val schema = StructType(Seq(StructField("value", StringType)))
+          val df = spark.createDataFrame(
+            spark.sparkContext.makeRDD(memorySink.latestBatchData), schema)
+          checkAnswer(df, data.map(_.toString).toDF("value"))
+        }
 
-      // Set up a query to read text files 2 at a time
-      val df = spark
-        .readStream
-        .option("maxFilesPerTrigger", 2)
-        .text(src.getCanonicalPath)
-      val q = df
-        .writeStream
-        .format("memory")
-        .queryName("file_data")
-        .start()
-        .asInstanceOf[StreamingQueryWrapper]
-        .streamingQuery
-      q.processAllAvailable()
-      val memorySink = q.sink.asInstanceOf[MemorySink]
-      val fileSource = getSourcesFromStreamingQuery(q).head
-
-      /** Check the data read in the last batch */
-      def checkLastBatchData(data: Int*): Unit = {
-        val schema = StructType(Seq(StructField("value", StringType)))
-        val df = spark.createDataFrame(
-          spark.sparkContext.makeRDD(memorySink.latestBatchData), schema)
-        checkAnswer(df, data.map(_.toString).toDF("value"))
-      }
+        def checkAllData(data: Seq[Char]): Unit = {
+          val schema = StructType(Seq(StructField("value", StringType)))
+          val df = spark.createDataFrame(
+            spark.sparkContext.makeRDD(memorySink.allData), schema)
+          checkAnswer(df, data.map(_.toString).toDF("value"))
+        }
 
-      def checkAllData(data: Seq[Int]): Unit = {
-        val schema = StructType(Seq(StructField("value", StringType)))
-        val df = spark.createDataFrame(
-          spark.sparkContext.makeRDD(memorySink.allData), schema)
-        checkAnswer(df, data.map(_.toString).toDF("value"))
-      }
+        /** Check how many batches have executed since the last time this check was made */
+        var lastBatchId = -1L
+        def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = {
+          require(lastBatchId >= 0)
+          assert(memorySink.latestBatchId.get === lastBatchId + numBatches)
+          lastBatchId = memorySink.latestBatchId.get
+        }
 
-      /** Check how many batches have executed since the last time this check was made */
-      var lastBatchId = -1L
-      def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = {
-        require(lastBatchId >= 0)
-        assert(memorySink.latestBatchId.get === lastBatchId + numBatches)
+        checkLastBatchData('c')  // (a and b) should be in batch 1, (c) should be in batch 2 (last)
+        checkAllData('a' to 'c')
         lastBatchId = memorySink.latestBatchId.get
-      }
 
-      checkLastBatchData(3)  // (1 and 2) should be in batch 1, (3) should be in batch 2 (last)
-      checkAllData(1 to 3)
-      lastBatchId = memorySink.latestBatchId.get
+        fileSource.withBatchingLocked {
+          createFile("d")
+          createFile("e")   // d and e should be in a batch
+          createFile("f")
+          createFile("g")   // f and g should be in the last batch
+        }
+        q.processAllAvailable()
+        checkNumBatchesSinceLastCheck(2)
+        checkLastBatchData('f', 'g')
+        checkAllData('a' to 'g')
+
+        fileSource.withBatchingLocked {
+          createFile("h")
+          createFile("i")    // h and i should be in a batch
+          createFile("j")
+          createFile("k")   // j and k should be in a batch
+          createFile("l")   // l should be in the last batch
+        }
+        q.processAllAvailable()
+        checkNumBatchesSinceLastCheck(3)
+        checkLastBatchData('l')
+        checkAllData('a' to 'l')
 
-      fileSource.withBatchingLocked {
-        createFile(4)
-        createFile(5)   // 4 and 5 should be in a batch
-        createFile(6)
-        createFile(7)   // 6 and 7 should be in the last batch
-      }
-      q.processAllAvailable()
-      checkNumBatchesSinceLastCheck(2)
-      checkLastBatchData(6, 7)
-      checkAllData(1 to 7)
-
-      fileSource.withBatchingLocked {
-        createFile(8)
-        createFile(9)    // 8 and 9 should be in a batch
-        createFile(10)
-        createFile(11)   // 10 and 11 should be in a batch
-        createFile(12)   // 12 should be in the last batch
+        q.stop()
       }
-      q.processAllAvailable()
-      checkNumBatchesSinceLastCheck(3)
-      checkLastBatchData(12)
-      checkAllData(1 to 12)
-
-      q.stop()
     }
   }
 
-  testQuietly("max files per trigger - incorrect values") {
-    val testTable = "maxFilesPerTrigger_test"
-    withTable(testTable) {
-      withTempDir { case src =>
-        def testMaxFilePerTriggerValue(value: String): Unit = {
-          val df = spark.readStream.option("maxFilesPerTrigger", value).text(src.getCanonicalPath)
-          val e = intercept[StreamingQueryException] {
-            // Note: `maxFilesPerTrigger` is checked in the stream thread when creating the source
-            val q = df.writeStream.format("memory").queryName(testTable).start()
-            try {
-              q.processAllAvailable()
-            } finally {
-              q.stop()
+  testQuietly("max bytes per trigger & max files per trigger - incorrect values") {

Review Comment:
   ditto.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to