showuon commented on code in PR #12347:
URL: https://github.com/apache/kafka/pull/12347#discussion_r921131155


##########
core/src/test/scala/unit/kafka/log/LogManagerTest.scala:
##########
@@ -638,6 +641,221 @@ class LogManagerTest {
     assertTrue(logManager.partitionsInitializing.isEmpty)
   }
 
+  private def appendRecordsToLog(time: MockTime, parentLogDir: File, 
partitionId: Int, brokerTopicStats: BrokerTopicStats, expectedSegmentsPerLog: 
Int): Unit = {
+    def createRecords = TestUtils.singletonRecords(value = "test".getBytes, 
timestamp = time.milliseconds)
+    val tpFile = new File(parentLogDir, s"$name-$partitionId")
+
+    val log = LogTestUtils.createLog(tpFile, logConfig, brokerTopicStats, 
time.scheduler, time, 0, 0,
+      5 * 60 * 1000, 60 * 60 * 1000, 
LogManager.ProducerIdExpirationCheckIntervalMs)
+
+    val numMessages = 20
+    try {
+      for (_ <- 0 until numMessages) {
+        log.appendAsLeader(createRecords, leaderEpoch = 0)
+      }
+
+      assertEquals(expectedSegmentsPerLog, log.numberOfSegments)
+    } finally {
+      log.close()
+    }
+  }
+
+  private def verifyRemainingLogsToRecoverMetric(spyLogManager: LogManager, 
expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    val logMetrics: ArrayBuffer[Gauge[Int]] = 
KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+      .filter { case (metric, _) => metric.getType == 
s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+      .map { case (_, gauge) => gauge }
+      .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    assertEquals(expectedParams.size, logMetrics.size)
+
+    val capturedPath: ArgumentCaptor[String] = 
ArgumentCaptor.forClass(classOf[String])
+    val capturedNumRemainingLogs: ArgumentCaptor[Int] = 
ArgumentCaptor.forClass(classOf[Int])
+
+    // Since we'll update numRemainingLogs from totalLogs to 0 for each log 
dir, so we need to add 1 here
+    val expectedCallTimes = expectedParams.values.map( num => num + 1 ).sum
+    verify(spyLogManager, 
times(expectedCallTimes)).updateNumRemainingLogs(any, capturedPath.capture(), 
capturedNumRemainingLogs.capture());
+
+    val paths = capturedPath.getAllValues
+    val numRemainingLogs = capturedNumRemainingLogs.getAllValues
+
+    // expected the end value is 0
+    logMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+
+    expectedParams.foreach {
+      case (path, totalLogs) =>
+        // make sure we update the numRemainingLogs from totalLogs to 0 in 
order for each log dir
+        var expectedCurRemainingLogs = totalLogs + 1
+        for (i <- 0 until paths.size()) {
+          if (paths.get(i).contains(path)) {
+            expectedCurRemainingLogs -= 1
+            assertEquals(expectedCurRemainingLogs, numRemainingLogs.get(i))
+          }
+        }
+        assertEquals(0, expectedCurRemainingLogs)
+    }
+  }
+
+  private def verifyRemainingSegmentsToRecoverMetric(spyLogManager: LogManager,
+                                                     logDirs: Seq[File],
+                                                     
recoveryThreadsPerDataDir: Int,
+                                                     mockMap: 
ConcurrentHashMap[String, Int],
+                                                     expectedParams: 
Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: ArrayBuffer[Gauge[Int]] = 
KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+          .filter { case (metric, _) => metric.getType == 
s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+          .map { case (_, gauge) => gauge }
+          .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    // expected each log dir has 2 metrics for each thread
+    assertEquals(recoveryThreadsPerDataDir * logDirs.size, 
logSegmentMetrics.size)
+
+    val capturedThreadName: ArgumentCaptor[String] = 
ArgumentCaptor.forClass(classOf[String])
+    val capturedNumRemainingSegments: ArgumentCaptor[Int] = 
ArgumentCaptor.forClass(classOf[Int])
+
+    // Since we'll update numRemainingSegments from totalSegments to 0 for 
each thread, so we need to add 1 here
+    val expectedCallTimes = expectedParams.values.map( num => num + 1 ).sum
+    verify(mockMap, 
times(expectedCallTimes)).put(capturedThreadName.capture(), 
capturedNumRemainingSegments.capture());
+
+    // expected the end value is 0
+    logSegmentMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+
+    val threadNames = capturedThreadName.getAllValues
+    val numRemainingSegments = capturedNumRemainingSegments.getAllValues
+
+    expectedParams.foreach {
+      case (threadName, totalSegments) =>
+        // make sure we update the numRemainingSegments from totalSegments to 
0 in order for each thread
+        var expectedCurRemainingSegments = totalSegments + 1
+        for (i <- 0 until threadNames.size) {
+          if (threadNames.get(i).contains(threadName)) {
+            expectedCurRemainingSegments -= 1
+            assertEquals(expectedCurRemainingSegments, 
numRemainingSegments.get(i))
+          }
+        }
+        assertEquals(0, expectedCurRemainingSegments)
+    }
+  }
+
+  private def verifyLogRecoverMetricsRemoved(spyLogManager: LogManager): Unit 
= {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    def logMetrics: mutable.Set[MetricName] = 
KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && 
metric.getName == "remainingLogsToRecover" }
+
+    assertTrue(logMetrics.isEmpty)
+
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: mutable.Set[MetricName] = 
KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && 
metric.getName == "remainingSegmentsToRecover" }
+
+    assertTrue(logSegmentMetrics.isEmpty)
+  }
+
+  @Test
+  def testLogRecoveryMetrics(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create logManager with expected recovery thread number
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = 
recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, 
indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
+    val mockTime = new MockTime()
+    val mockMap = mock(classOf[ConcurrentHashMap[String, Int]])
+    val mockBrokerTopicStats = mock(classOf[BrokerTopicStats])
+    val expectedSegmentsPerLog = 2
+
+    // create log segments for log recovery in each log dir
+    appendRecordsToLog(mockTime, logDir1, 0, mockBrokerTopicStats, 
expectedSegmentsPerLog)
+    appendRecordsToLog(mockTime, logDir2, 1, mockBrokerTopicStats, 
expectedSegmentsPerLog)
+
+    // intercept loadLog method to pass expected parameter to do log recovery
+    doAnswer { invocation =>
+      val dir: File = invocation.getArgument(0)
+      val topicConfigOverrides: mutable.Map[String, LogConfig] = 
invocation.getArgument(5)
+
+      val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
+      val config = topicConfigOverrides.getOrElse(topicPartition.topic, 
logConfig)
+
+      UnifiedLog(
+        dir = dir,
+        config = config,
+        logStartOffset = 0,
+        recoveryPoint = 0,
+        maxTransactionTimeoutMs = 5 * 60 * 1000,
+        maxProducerIdExpirationMs = 5 * 60 * 1000,
+        producerIdExpirationCheckIntervalMs = 
LogManager.ProducerIdExpirationCheckIntervalMs,
+        scheduler = mockTime.scheduler,
+        time = mockTime,
+        brokerTopicStats = mockBrokerTopicStats,
+        logDirFailureChannel = mock(classOf[LogDirFailureChannel]),
+        // not clean shutdown
+        lastShutdownClean = false,
+        topicId = None,
+        keepPartitionMetadataFile = false,
+        // pass mock map for verification later
+        numRemainingSegments = mockMap)
+
+    } .when(spyLogManager).loadLog(any[File], any[Boolean], 
any[Map[TopicPartition, Long]], any[Map[TopicPartition, Long]],
+      any[LogConfig], any[Map[String, LogConfig]], any[ConcurrentMap[String, 
Int]])
+
+    // do nothing for removeLogRecoveryMetrics for metrics verification
+    doNothing().when(spyLogManager).removeLogRecoveryMetrics(anyInt)
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any, any, 
ArgumentMatchers.eq(recoveryThreadsPerDataDir))
+    verify(spyLogManager, 
times(1)).removeLogRecoveryMetrics(ArgumentMatchers.eq(recoveryThreadsPerDataDir))
+
+    // expected 1 log in each log dir since we created 2 partitions with 2 log 
dirs
+    val expectedRemainingLogsParams = Map[String, Int](logDir1.getAbsolutePath 
-> 1, logDir2.getAbsolutePath -> 1)
+    verifyRemainingLogsToRecoverMetric(spyLogManager, 
expectedRemainingLogsParams)
+
+    val expectedRemainingSegmentsParams = Map[String, Int](
+      logDir1.getAbsolutePath -> expectedSegmentsPerLog, 
logDir2.getAbsolutePath -> expectedSegmentsPerLog)
+    verifyRemainingSegmentsToRecoverMetric(spyLogManager, logDirs, 
recoveryThreadsPerDataDir, mockMap, expectedRemainingSegmentsParams)
+  }
+
+  @Test
+  def testLogRecoveryMetricsShouldBeRemovedAfterLogRecovered(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create logManager with expected recovery thread number
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = 
recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    // intercept loadLog method to pass expected parameter to do log recovery
+    doAnswer { _ =>
+      // simulate the recovery thread is resized during log recovery
+      spyLogManager.resizeRecoveryThreadPool(recoveryThreadsPerDataDir - 1)

Review Comment:
   Ah, good point again! Yes, I confirmed that in both ZK mode and KRaft mode, 
the broker only comes online (ready) after log recovery completes, so there is 
no chance to change the config during log recovery. Test updated. Thanks.
   
   cc @tombentley 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to