This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3835515feaf KAFKA-16541 Fix potential leader-epoch checkpoint file corruption (#15993)
3835515feaf is described below

commit 3835515feaf7cb5bb7de3c4d63794e79100eb62a
Author: Okada Haruki <ocadar...@gmail.com>
AuthorDate: Thu Jun 6 15:10:13 2024 +0900

    KAFKA-16541 Fix potential leader-epoch checkpoint file corruption (#15993)
    
    A patch for KAFKA-15046 removed the fsync from
    LeaderEpochFileCache#truncateFromStart/End for performance reasons, but
    it turned out that this could corrupt the leader-epoch checkpoint file
    on an ungraceful OS shutdown, i.e. when the OS shuts down while the
    kernel is still writing dirty pages back to the device.
    
    To address this problem, this PR makes the following changes:
    (1) Revert LeaderEpochCheckpoint#write to always fsync
    (2) truncateFromStart/End now call LeaderEpochCheckpoint#write
        asynchronously on a scheduler thread (see the sketch below)
    (3) UnifiedLog#maybeCreateLeaderEpochCache now loads epoch entries from
        the checkpoint file only when the current cache is absent
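
    For illustration, a minimal, self-contained sketch of the pattern in
    change (2), using the JDK's ScheduledExecutorService as a stand-in for
    Kafka's scheduler; all names here are illustrative, not the actual
    implementation:

        import java.util.List;
        import java.util.concurrent.CopyOnWriteArrayList;
        import java.util.concurrent.Executors;
        import java.util.concurrent.ScheduledExecutorService;

        // In-memory truncation happens on the caller thread; the fsync-ing
        // checkpoint write is deferred to a background scheduler thread.
        public class AsyncFlushSketch {
            // each entry is {epoch, startOffset}
            private final List<long[]> epochs = new CopyOnWriteArrayList<>();
            private final ScheduledExecutorService scheduler =
                    Executors.newSingleThreadScheduledExecutor();

            public void truncateFromEndAsyncFlush(long endOffset) {
                // Cheap in-memory mutation stays on the hot path; entries at
                // or beyond endOffset are dropped.
                epochs.removeIf(e -> e[1] >= endOffset);
                // The durable (fsync-ing) write happens off the hot path.
                scheduler.execute(this::writeCheckpointWithFsync);
            }

            private void writeCheckpointWithFsync() {
                // Placeholder: write all entries to a temp file, flush, fsync,
                // then atomically rename over the checkpoint file.
            }
        }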
    
    Reviewers: Jun Rao <jun...@gmail.com>
---
 .../java/kafka/log/remote/RemoteLogManager.java    |  46 ++++--
 core/src/main/scala/kafka/log/LogLoader.scala      |   4 +-
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  43 +++---
 .../server/checkpoints/OffsetCheckpointFile.scala  |   2 +-
 .../kafka/log/remote/RemoteLogManagerTest.java     | 120 ++++++++-------
 .../unit/kafka/cluster/PartitionLockTest.scala     |   3 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   |   3 +-
 .../unit/kafka/log/LogCleanerManagerTest.scala     |   3 +-
 .../test/scala/unit/kafka/log/LogCleanerTest.scala |   5 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |  23 ++-
 .../test/scala/unit/kafka/log/LogSegmentTest.scala |  16 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     |   5 +-
 .../InMemoryLeaderEpochCheckpointTest.scala        |  58 -------
 ...ffsetCheckpointFileWithFailureHandlerTest.scala |  18 ++-
 .../server/epoch/LeaderEpochFileCacheTest.scala    |  64 ++++----
 .../scala/unit/kafka/utils/SchedulerTest.scala     |   3 +-
 .../apache/kafka/server/common/CheckpointFile.java |   8 +-
 .../CheckpointFileWithFailureHandler.java          |  22 ++-
 .../checkpoint/InMemoryLeaderEpochCheckpoint.java  |  63 --------
 .../checkpoint/LeaderEpochCheckpoint.java          |  34 -----
 .../checkpoint/LeaderEpochCheckpointFile.java      |  11 +-
 .../internals/epoch/LeaderEpochFileCache.java      | 168 ++++++++++++++++-----
 22 files changed, 364 insertions(+), 358 deletions(-)

diff --git a/core/src/main/java/kafka/log/remote/RemoteLogManager.java b/core/src/main/java/kafka/log/remote/RemoteLogManager.java
index 5b0d91ff439..c524238f623 100644
--- a/core/src/main/java/kafka/log/remote/RemoteLogManager.java
+++ b/core/src/main/java/kafka/log/remote/RemoteLogManager.java
@@ -46,6 +46,7 @@ import org.apache.kafka.common.utils.KafkaThread;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.server.common.CheckpointFile;
 import org.apache.kafka.server.common.OffsetAndEpoch;
 import org.apache.kafka.server.config.ServerConfigs;
 import org.apache.kafka.server.log.remote.metadata.storage.ClassLoaderAwareRemoteLogMetadataManager;
@@ -61,7 +62,7 @@ import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;
 import org.apache.kafka.server.log.remote.storage.RemoteStorageException;
 import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;
 import org.apache.kafka.server.metrics.KafkaMetricsGroup;
-import org.apache.kafka.storage.internals.checkpoint.InMemoryLeaderEpochCheckpoint;
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile;
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache;
 import org.apache.kafka.storage.internals.log.AbortedTxn;
 import org.apache.kafka.storage.internals.log.EpochEntry;
@@ -83,12 +84,16 @@ import org.slf4j.LoggerFactory;
 import scala.Option;
 import scala.collection.JavaConverters;
 
+import java.io.BufferedWriter;
+import java.io.ByteArrayOutputStream;
 import java.io.Closeable;
 import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.OutputStreamWriter;
 import java.lang.reflect.InvocationTargetException;
 import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Path;
 import java.security.PrivilegedAction;
 import java.util.ArrayList;
@@ -612,25 +617,23 @@ public class RemoteLogManager implements Closeable {
     }
 
     /**
-     * Returns the leader epoch checkpoint by truncating with the given start[exclusive] and end[inclusive] offset
+     * Returns the leader epoch entries within the range of the given start[inclusive] and end[exclusive] offset.
+     * <p>
+     * Visible for testing.
      *
      * @param log         The actual log from where to take the leader-epoch checkpoint
-     * @param startOffset The start offset of the checkpoint file (exclusive in the truncation).
+     * @param startOffset The start offset of the epoch entries (inclusive).
      *                    If start offset is 6, then it will retain an entry at offset 6.
-     * @param endOffset   The end offset of the checkpoint file (inclusive in the truncation)
+     * @param endOffset   The end offset of the epoch entries (exclusive)
      *                    If end offset is 100, then it will remove the entries greater than or equal to 100.
-     * @return the truncated leader epoch checkpoint
+     * @return the leader epoch entries
      */
-    InMemoryLeaderEpochCheckpoint getLeaderEpochCheckpoint(UnifiedLog log, long startOffset, long endOffset) {
-        InMemoryLeaderEpochCheckpoint checkpoint = new InMemoryLeaderEpochCheckpoint();
+    List<EpochEntry> getLeaderEpochEntries(UnifiedLog log, long startOffset, long endOffset) {
         if (log.leaderEpochCache().isDefined()) {
-            LeaderEpochFileCache cache = log.leaderEpochCache().get().writeTo(checkpoint);
-            if (startOffset >= 0) {
-                cache.truncateFromStart(startOffset);
-            }
-            cache.truncateFromEnd(endOffset);
+            return log.leaderEpochCache().get().epochEntriesInRange(startOffset, endOffset);
+        } else {
+            return Collections.emptyList();
         }
-        return checkpoint;
     }
 
     class RLMTask extends CancellableRunnable {
@@ -788,7 +791,7 @@ public class RemoteLogManager implements Closeable {
             long endOffset = nextSegmentBaseOffset - 1;
             File producerStateSnapshotFile = log.producerStateManager().fetchSnapshot(nextSegmentBaseOffset).orElse(null);
 
-            List<EpochEntry> epochEntries = getLeaderEpochCheckpoint(log, segment.baseOffset(), nextSegmentBaseOffset).read();
+            List<EpochEntry> epochEntries = getLeaderEpochEntries(log, segment.baseOffset(), nextSegmentBaseOffset);
             Map<Integer, Long> segmentLeaderEpochs = new HashMap<>(epochEntries.size());
             epochEntries.forEach(entry -> segmentLeaderEpochs.put(entry.epoch, entry.startOffset));
 
@@ -798,7 +801,7 @@ public class RemoteLogManager implements Closeable {
 
             remoteLogMetadataManager.addRemoteLogSegmentMetadata(copySegmentStartedRlsm).get();
 
-            ByteBuffer leaderEpochsIndex = getLeaderEpochCheckpoint(log, -1, nextSegmentBaseOffset).readAsByteBuffer();
+            ByteBuffer leaderEpochsIndex = epochEntriesAsByteBuffer(getLeaderEpochEntries(log, -1, nextSegmentBaseOffset));
             LogSegmentData segmentData = new LogSegmentData(logFile.toPath(), toPathIfExists(segment.offsetIndex().file()),
                     toPathIfExists(segment.timeIndex().file()), Optional.ofNullable(toPathIfExists(segment.txnIndex().file())),
                     producerStateSnapshotFile.toPath(), leaderEpochsIndex);
@@ -1751,6 +1754,19 @@ public class RemoteLogManager implements Closeable {
         LOGGER.info("Shutting down of thread pool {} is completed", poolName);
     }
 
+    //Visible for testing
+    static ByteBuffer epochEntriesAsByteBuffer(List<EpochEntry> epochEntries) throws IOException {
+        ByteArrayOutputStream stream = new ByteArrayOutputStream();
+        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream, StandardCharsets.UTF_8))) {
+            CheckpointFile.CheckpointWriteBuffer<EpochEntry> writeBuffer =
+                    new CheckpointFile.CheckpointWriteBuffer<>(writer, 0, LeaderEpochCheckpointFile.FORMATTER);
+            writeBuffer.write(epochEntries);
+            writer.flush();
+        }
+
+        return ByteBuffer.wrap(stream.toByteArray());
+    }
+
     private void removeRemoteTopicPartitionMetrics(TopicIdPartition topicIdPartition) {
         String topic = topicIdPartition.topic();
         if (!brokerTopicStats.isTopicStatsExisted(topicIdPartition.topic())) {
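
For reference, the ByteBuffer produced by epochEntriesAsByteBuffer above uses the
same plain-text layout as the on-disk leader-epoch checkpoint file: a version line
(0), a line with the number of entries, then one "<epoch> <startOffset>" pair per
line; this is the layout the new testEpochEntriesAsByteBuffer further below asserts
on. For the entries (epoch 0, offset 100) and (epoch 1, offset 200), the serialized
bytes would read:

    0
    2
    0 100
    1 200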
diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index b0f1fdd0e1c..b3b0ec2c633 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -173,14 +173,14 @@ class LogLoader(
       }
     }
 
-    leaderEpochCache.ifPresent(_.truncateFromEnd(nextOffset))
+    leaderEpochCache.ifPresent(_.truncateFromEndAsyncFlush(nextOffset))
     val newLogStartOffset = if (isRemoteLogEnabled) {
       logStartOffsetCheckpoint
     } else {
       math.max(logStartOffsetCheckpoint, segments.firstSegment.get.baseOffset)
     }
     // The earliest leader epoch may not be flushed during a hard failure. Recover it here.
-    leaderEpochCache.ifPresent(_.truncateFromStart(logStartOffsetCheckpoint))
+    leaderEpochCache.ifPresent(_.truncateFromStartAsyncFlush(logStartOffsetCheckpoint))
 
     // Any segment loading or recovery code must not use producerStateManager, so that we can build the full state here
     // from scratch.
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index ba1c8656a88..bef18806b0d 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -521,7 +521,8 @@ class UnifiedLog(@volatile var logStartOffset: Long,
   }
 
   private def initializeLeaderEpochCache(): Unit = lock synchronized {
-    leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(dir, topicPartition, logDirFailureChannel, recordVersion, logIdent)
+    leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      dir, topicPartition, logDirFailureChannel, recordVersion, logIdent, leaderEpochCache, scheduler)
   }
 
   private def updateHighWatermarkWithLogEndOffset(): Unit = {
@@ -1015,7 +1016,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           updatedLogStartOffset = true
           updateLogStartOffset(newLogStartOffset)
           info(s"Incremented log start offset to $newLogStartOffset due to 
$reason")
-          leaderEpochCache.foreach(_.truncateFromStart(logStartOffset))
+          
leaderEpochCache.foreach(_.truncateFromStartAsyncFlush(logStartOffset))
           producerStateManager.onLogStartOffsetIncremented(newLogStartOffset)
           maybeIncrementFirstUnstableOffset()
         }
@@ -1813,7 +1814,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
         // and inserted the first start offset entry, but then failed to append any entries
         // before another leader was elected.
         lock synchronized {
-          leaderEpochCache.foreach(_.truncateFromEnd(logEndOffset))
+          leaderEpochCache.foreach(_.truncateFromEndAsyncFlush(logEndOffset))
         }
 
         false
@@ -1826,7 +1827,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           } else {
             val deletedSegments = localLog.truncateTo(targetOffset)
             deleteProducerSnapshots(deletedSegments, asyncDelete = true)
-            leaderEpochCache.foreach(_.truncateFromEnd(targetOffset))
+            leaderEpochCache.foreach(_.truncateFromEndAsyncFlush(targetOffset))
             logStartOffset = math.min(targetOffset, logStartOffset)
             rebuildProducerState(targetOffset, producerStateManager)
             if (highWatermark >= localLog.logEndOffset)
@@ -2011,12 +2012,17 @@ object UnifiedLog extends Logging {
     Files.createDirectories(dir.toPath)
     val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
     val segments = new LogSegments(topicPartition)
+    // The created leaderEpochCache will be truncated by LogLoader if necessary,
+    // so it is guaranteed that the epoch entries will be correct even when the
+    // on-disk checkpoint is stale (due to the async nature of
+    // LeaderEpochFileCache#truncateFromStart/End).
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
       dir,
       topicPartition,
       logDirFailureChannel,
       config.recordVersion,
-      s"[UnifiedLog partition=$topicPartition, dir=${dir.getParent}] ")
+      s"[UnifiedLog partition=$topicPartition, dir=${dir.getParent}] ",
+      None,
+      scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, dir,
       maxTransactionTimeoutMs, producerStateManagerConfig, time)
     val isRemoteLogEnabled = UnifiedLog.isRemoteLogEnabled(remoteStorageSystemEnable, config, topicPartition.topic)
@@ -2103,7 +2109,8 @@ object UnifiedLog extends Logging {
   }
 
   /**
-   * If the recordVersion is >= RecordVersion.V2, then create and return a LeaderEpochFileCache.
+   * If the recordVersion is >= RecordVersion.V2, create a new LeaderEpochFileCache instance.
+   * The epoch entries are loaded from the provided currentCache if present, otherwise from the backing checkpoint file.
    * Otherwise, the message format is considered incompatible and the existing LeaderEpoch file
    * is deleted.
    *
@@ -2112,33 +2119,29 @@ object UnifiedLog extends Logging {
    * @param logDirFailureChannel The LogDirFailureChannel to asynchronously handle log dir failure
    * @param recordVersion        The record version
    * @param logPrefix            The logging prefix
+   * @param currentCache         The current LeaderEpochFileCache instance (if any)
+   * @param scheduler            The scheduler for executing asynchronous tasks
    * @return The new LeaderEpochFileCache instance (if created), none otherwise
    */
   def maybeCreateLeaderEpochCache(dir: File,
                                   topicPartition: TopicPartition,
                                   logDirFailureChannel: LogDirFailureChannel,
                                   recordVersion: RecordVersion,
-                                  logPrefix: String): Option[LeaderEpochFileCache] = {
+                                  logPrefix: String,
+                                  currentCache: Option[LeaderEpochFileCache],
+                                  scheduler: Scheduler): Option[LeaderEpochFileCache] = {
     val leaderEpochFile = LeaderEpochCheckpointFile.newFile(dir)
 
-    def newLeaderEpochFileCache(): LeaderEpochFileCache = {
-      val checkpointFile = new LeaderEpochCheckpointFile(leaderEpochFile, logDirFailureChannel)
-      new LeaderEpochFileCache(topicPartition, checkpointFile)
-    }
-
     if (recordVersion.precedes(RecordVersion.V2)) {
-      val currentCache = if (leaderEpochFile.exists())
-        Some(newLeaderEpochFileCache())
-      else
-        None
-
-      if (currentCache.exists(_.nonEmpty))
+      if (leaderEpochFile.exists()) {
         warn(s"${logPrefix}Deleting non-empty leader epoch cache due to incompatible message format $recordVersion")
-
+      }
       Files.deleteIfExists(leaderEpochFile.toPath)
       None
     } else {
-      Some(newLeaderEpochFileCache())
+      val checkpointFile = new LeaderEpochCheckpointFile(leaderEpochFile, logDirFailureChannel)
+      currentCache.map(_.withCheckpoint(checkpointFile))
+        .orElse(Some(new LeaderEpochFileCache(topicPartition, checkpointFile, scheduler)))
     }
   }
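
Note on the new else-branch above: when a currentCache is passed in, its in-memory
entries are trusted and only the backing checkpoint file is swapped via
withCheckpoint; the file is read only when no cache exists yet, since the on-disk
contents may lag behind the in-memory state under the new asynchronous flushing. A
hedged, self-contained Java rendering of that decision (stand-in types, not the
actual Kafka classes):

    import java.util.Optional;

    final class ReuseOrLoadSketch {
        record Checkpoint(String path) {}
        record EpochCache(Checkpoint file) {
            EpochCache withCheckpoint(Checkpoint newFile) { return new EpochCache(newFile); }
        }

        // Trust in-memory entries when present; read the (possibly stale)
        // checkpoint file only on first creation.
        static Optional<EpochCache> maybeCreate(Optional<EpochCache> current, Checkpoint file) {
            return current
                    .map(c -> c.withCheckpoint(file))          // rebind backing file only
                    .or(() -> Optional.of(loadFromFile(file))); // first creation: read file
        }

        static EpochCache loadFromFile(Checkpoint file) {
            return new EpochCache(file); // placeholder for reading entries from disk
        }
    }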
 
diff --git a/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala b/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala
index de3283d21fd..084e46c5ef2 100644
--- a/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala
+++ b/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala
@@ -68,7 +68,7 @@ class OffsetCheckpointFile(val file: File, logDirFailureChannel: LogDirFailureCh
   def write(offsets: Map[TopicPartition, Long]): Unit = {
     val list: java.util.List[(TopicPartition, Long)] = new java.util.ArrayList[(TopicPartition, Long)](offsets.size)
     offsets.foreach(x => list.add(x))
-    checkpoint.write(list, true)
+    checkpoint.write(list)
   }
 
   def read(): Map[TopicPartition, Long] = {
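
The write call above no longer takes a sync flag because, per change (1), every
checkpoint write now fsyncs again. As a minimal, self-contained sketch of that
durable-write sequence (write to a temp file, flush, fsync, atomic rename),
illustrative of the pattern rather than Kafka's actual CheckpointFile code:

    import java.io.BufferedWriter;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.OutputStreamWriter;
    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.StandardCopyOption;

    final class DurableCheckpointWrite {
        static void write(Path target, Iterable<String> lines) throws IOException {
            Path tmp = target.resolveSibling(target.getFileName() + ".tmp");
            try (FileOutputStream out = new FileOutputStream(tmp.toFile());
                 BufferedWriter writer = new BufferedWriter(
                         new OutputStreamWriter(out, StandardCharsets.UTF_8))) {
                for (String line : lines) {
                    writer.write(line);
                    writer.newLine();
                }
                writer.flush();
                // Without this fsync, an ungraceful OS shutdown can persist the
                // rename but not the data, leaving a corrupt checkpoint.
                out.getFD().sync();
            }
            Files.move(tmp, target, StandardCopyOption.ATOMIC_MOVE);
        }
    }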
diff --git a/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java b/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
index 50b581fdf4e..19fa4a8a443 100644
--- a/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
+++ b/core/src/test/java/kafka/log/remote/RemoteLogManagerTest.java
@@ -60,14 +60,15 @@ import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;
 import org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType;
 import org.apache.kafka.server.metrics.KafkaMetricsGroup;
 import org.apache.kafka.server.metrics.KafkaYammerMetrics;
-import org.apache.kafka.storage.internals.checkpoint.InMemoryLeaderEpochCheckpoint;
-import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint;
+import org.apache.kafka.server.util.MockScheduler;
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile;
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache;
 import org.apache.kafka.storage.internals.log.EpochEntry;
 import org.apache.kafka.storage.internals.log.FetchDataInfo;
 import org.apache.kafka.storage.internals.log.FetchIsolation;
 import org.apache.kafka.storage.internals.log.LazyIndex;
 import org.apache.kafka.storage.internals.log.LogConfig;
+import org.apache.kafka.storage.internals.log.LogDirFailureChannel;
 import org.apache.kafka.storage.internals.log.LogFileUtils;
 import org.apache.kafka.storage.internals.log.LogSegment;
 import org.apache.kafka.storage.internals.log.OffsetIndex;
@@ -89,16 +90,19 @@ import org.mockito.Mockito;
 import scala.Option;
 import scala.collection.JavaConverters;
 
+import java.io.BufferedReader;
 import java.io.ByteArrayInputStream;
 import java.io.File;
 import java.io.InputStream;
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.UncheckedIOException;
 import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -197,25 +201,16 @@ public class RemoteLogManagerTest {
     private final EpochEntry epochEntry1 = new EpochEntry(1, 100);
     private final EpochEntry epochEntry2 = new EpochEntry(2, 200);
     private final List<EpochEntry> totalEpochEntries = Arrays.asList(epochEntry0, epochEntry1, epochEntry2);
-    private final LeaderEpochCheckpoint checkpoint = new LeaderEpochCheckpoint() {
-        List<EpochEntry> epochs = Collections.emptyList();
-
-        @Override
-        public void write(Collection<EpochEntry> epochs, boolean ignored) {
-            this.epochs = new ArrayList<>(epochs);
-        }
-
-        @Override
-        public List<EpochEntry> read() {
-            return epochs;
-        }
-    };
+    private LeaderEpochCheckpointFile checkpoint;
     private final AtomicLong currentLogStartOffset = new AtomicLong(0L);
 
     private UnifiedLog mockLog = mock(UnifiedLog.class);
 
+    private final MockScheduler scheduler = new MockScheduler(time);
+
     @BeforeEach
     void setUp() throws Exception {
+        checkpoint = new LeaderEpochCheckpointFile(TestUtils.tempFile(), new LogDirFailureChannel(1));
         topicIds.put(leaderTopicIdPartition.topicPartition().topic(), leaderTopicIdPartition.topicId());
         topicIds.put(followerTopicIdPartition.topicPartition().topic(), followerTopicIdPartition.topicId());
         Properties props = kafka.utils.TestUtils.createDummyBrokerConfig();
@@ -253,13 +248,11 @@ public class RemoteLogManagerTest {
     @Test
     void testGetLeaderEpochCheckpoint() {
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
-        InMemoryLeaderEpochCheckpoint inMemoryCheckpoint = remoteLogManager.getLeaderEpochCheckpoint(mockLog, 0, 300);
-        assertEquals(totalEpochEntries, inMemoryCheckpoint.read());
+        assertEquals(totalEpochEntries, remoteLogManager.getLeaderEpochEntries(mockLog, 0, 300));
 
-        InMemoryLeaderEpochCheckpoint inMemoryCheckpoint2 = remoteLogManager.getLeaderEpochCheckpoint(mockLog, 100, 200);
-        List<EpochEntry> epochEntries = inMemoryCheckpoint2.read();
+        List<EpochEntry> epochEntries = remoteLogManager.getLeaderEpochEntries(mockLog, 100, 200);
         assertEquals(1, epochEntries.size());
         assertEquals(epochEntry1, epochEntries.get(0));
     }
@@ -271,7 +264,7 @@ public class RemoteLogManagerTest {
                 new EpochEntry(1, 500)
         );
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
         OffsetAndEpoch offsetAndEpoch = remoteLogManager.findHighestRemoteOffset(tpId, mockLog);
@@ -285,7 +278,7 @@ public class RemoteLogManagerTest {
                 new EpochEntry(1, 500)
         );
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
         when(remoteLogMetadataManager.highestOffsetForEpoch(eq(tpId), anyInt())).thenAnswer(ans -> {
@@ -308,7 +301,7 @@ public class RemoteLogManagerTest {
                 new EpochEntry(2, 300)
         );
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         TopicIdPartition tpId = new TopicIdPartition(Uuid.randomUuid(), tp);
         when(remoteLogMetadataManager.highestOffsetForEpoch(eq(tpId), anyInt())).thenAnswer(ans -> {
@@ -470,7 +463,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(-1L));
 
@@ -584,7 +577,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(-1L));
 
@@ -684,7 +677,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
@@ -803,7 +796,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt()))
                 .thenReturn(Optional.of(0L))
@@ -917,7 +910,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
@@ -1067,7 +1060,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
@@ -1140,7 +1133,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(remoteLogMetadataManager.highestOffsetForEpoch(any(TopicIdPartition.class), anyInt())).thenReturn(Optional.of(0L));
 
@@ -1176,7 +1169,7 @@ public class RemoteLogManagerTest {
 
         // leader epoch preparation
         checkpoint.write(totalEpochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         // Throw a retryable exception to indicate that the remote log metadata manager is not initialized yet
@@ -1258,9 +1251,7 @@ public class RemoteLogManagerTest {
         assertEquals(tempFile.getAbsolutePath(), logSegmentData.logSegment().toAbsolutePath().toString());
         assertEquals(mockProducerSnapshotIndex.getAbsolutePath(), logSegmentData.producerSnapshotIndex().toAbsolutePath().toString());
 
-        InMemoryLeaderEpochCheckpoint inMemoryLeaderEpochCheckpoint = new InMemoryLeaderEpochCheckpoint();
-        inMemoryLeaderEpochCheckpoint.write(expectedLeaderEpoch);
-        assertEquals(inMemoryLeaderEpochCheckpoint.readAsByteBuffer(), logSegmentData.leaderEpochIndex());
+        assertEquals(RemoteLogManager.epochEntriesAsByteBuffer(expectedLeaderEpoch), logSegmentData.leaderEpochIndex());
     }
 
     @Test
@@ -1379,7 +1370,7 @@ public class RemoteLogManagerTest {
         TreeMap<Integer, Long> validSegmentEpochs = new TreeMap<>();
         validSegmentEpochs.put(targetLeaderEpoch, startOffset);
 
-        LeaderEpochFileCache leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         leaderEpochFileCache.assign(4, 99L);
         leaderEpochFileCache.assign(5, 99L);
         leaderEpochFileCache.assign(targetLeaderEpoch, startOffset);
@@ -1414,7 +1405,7 @@ public class RemoteLogManagerTest {
         validSegmentEpochs.put(targetLeaderEpoch - 1, startOffset - 1); // invalid epochs not aligning with leader epoch cache
         validSegmentEpochs.put(targetLeaderEpoch, startOffset);
 
-        LeaderEpochFileCache leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         leaderEpochFileCache.assign(4, 99L);
         leaderEpochFileCache.assign(5, 99L);
         leaderEpochFileCache.assign(targetLeaderEpoch, startOffset);
@@ -1445,7 +1436,7 @@ public class RemoteLogManagerTest {
         TreeMap<Integer, Long> validSegmentEpochs = new TreeMap<>();
         validSegmentEpochs.put(targetLeaderEpoch, startOffset);
 
-        LeaderEpochFileCache leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache leaderEpochFileCache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         leaderEpochFileCache.assign(4, 99L);
         leaderEpochFileCache.assign(5, 99L);
         leaderEpochFileCache.assign(targetLeaderEpoch, startOffset);
@@ -1902,7 +1893,7 @@ public class RemoteLogManagerTest {
         epochEntries.add(new EpochEntry(2, 550L));
         checkpoint.write(epochEntries);
 
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         long timestamp = time.milliseconds();
@@ -1940,7 +1931,7 @@ public class RemoteLogManagerTest {
         epochEntries.add(new EpochEntry(2, 550L));
         checkpoint.write(epochEntries);
 
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
         when(mockLog.localLogStartOffset()).thenReturn(250L);
         when(remoteLogMetadataManager.listRemoteLogSegments(eq(leaderTopicIdPartition), anyInt()))
@@ -1965,7 +1956,7 @@ public class RemoteLogManagerTest {
         epochEntries.add(new EpochEntry(2, 550L));
         checkpoint.write(epochEntries);
 
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(leaderTopicIdPartition.topicPartition(), checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         RemoteLogSegmentMetadata metadata = mock(RemoteLogSegmentMetadata.class);
@@ -2008,7 +1999,7 @@ public class RemoteLogManagerTest {
 
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
@@ -2061,7 +2052,7 @@ public class RemoteLogManagerTest {
 
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
@@ -2132,7 +2123,7 @@ public class RemoteLogManagerTest {
                 .thenAnswer(ans -> metadataList.iterator());
 
         checkpoint.write(epochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         Map<String, Long> logProps = new HashMap<>();
@@ -2194,7 +2185,7 @@ public class RemoteLogManagerTest {
 
         List<EpochEntry> epochEntries = Collections.singletonList(epochEntry0);
         checkpoint.write(epochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
         when(mockLog.topicPartition()).thenReturn(leaderTopicIdPartition.topicPartition());
@@ -2247,7 +2238,7 @@ public class RemoteLogManagerTest {
                 new EpochEntry(4, 100L)
         );
         checkpoint.write(epochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         int currentLeaderEpoch = epochEntries.get(epochEntries.size() - 1).epoch;
 
         long localLogSegmentsSize = 512L;
@@ -2285,7 +2276,7 @@ public class RemoteLogManagerTest {
                 new EpochEntry(4, 100L)
         );
         checkpoint.write(epochEntries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
         int currentLeaderEpoch = epochEntries.get(epochEntries.size() - 1).epoch;
 
         long localLogSegmentsSize = 512L;
@@ -2372,7 +2363,7 @@ public class RemoteLogManagerTest {
                     .thenReturn(remoteLogSegmentMetadatas.iterator());
 
             checkpoint.write(epochEntries);
-            LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint);
+            LeaderEpochFileCache cache = new LeaderEpochFileCache(tp, checkpoint, scheduler);
             when(mockLog.leaderEpochCache()).thenReturn(Option.apply(cache));
 
             Map<String, Long> logProps = new HashMap<>();
@@ -2447,11 +2438,17 @@ public class RemoteLogManagerTest {
     private Map<Integer, Long> truncateAndGetLeaderEpochs(List<EpochEntry> entries,
                                                           Long startOffset,
                                                           Long endOffset) {
-        InMemoryLeaderEpochCheckpoint myCheckpoint = new InMemoryLeaderEpochCheckpoint();
+        LeaderEpochCheckpointFile myCheckpoint;
+        try {
+            myCheckpoint = new LeaderEpochCheckpointFile(
+                    TestUtils.tempFile(), new LogDirFailureChannel(1));
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
         myCheckpoint.write(entries);
-        LeaderEpochFileCache cache = new LeaderEpochFileCache(null, myCheckpoint);
-        cache.truncateFromStart(startOffset);
-        cache.truncateFromEnd(endOffset);
+        LeaderEpochFileCache cache = new LeaderEpochFileCache(null, myCheckpoint, scheduler);
+        cache.truncateFromStartAsyncFlush(startOffset);
+        cache.truncateFromEndAsyncFlush(endOffset);
         return myCheckpoint.read().stream().collect(Collectors.toMap(e -> e.epoch, e -> e.startOffset));
     }
 
@@ -2678,7 +2675,7 @@ public class RemoteLogManagerTest {
 
         }
     }
-    
+
     @Test
     public void testCopyQuotaManagerConfig() {
         Properties defaultProps = new Properties();
@@ -2698,7 +2695,7 @@ public class RemoteLogManagerTest {
         assertEquals(31, rlmCopyQuotaManagerConfig.numQuotaSamples());
         assertEquals(1, rlmCopyQuotaManagerConfig.quotaWindowSizeSeconds());
     }
-    
+
     @Test
     public void testFetchQuotaManagerConfig() {
         Properties defaultProps = new Properties();
@@ -2719,6 +2716,21 @@ public class RemoteLogManagerTest {
         assertEquals(1, rlmFetchQuotaManagerConfig.quotaWindowSizeSeconds());
     }
 
+    @Test
+    public void testEpochEntriesAsByteBuffer() throws Exception {
+        int expectedEpoch = 0;
+        long expectedStartOffset = 1L;
+        int expectedVersion = 0;
+        List<EpochEntry> epochs = Arrays.asList(new EpochEntry(expectedEpoch, expectedStartOffset));
+        ByteBuffer buffer = RemoteLogManager.epochEntriesAsByteBuffer(epochs);
+        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(new ByteArrayInputStream(buffer.array()), StandardCharsets.UTF_8));
+
+        assertEquals(String.valueOf(expectedVersion), bufferedReader.readLine());
+        assertEquals(String.valueOf(epochs.size()), bufferedReader.readLine());
+        assertEquals(expectedEpoch + " " + expectedStartOffset, bufferedReader.readLine());
+    }
+
+
     private Partition mockPartition(TopicIdPartition topicIdPartition) {
         TopicPartition tp = topicIdPartition.topicPartition();
         Partition partition = mock(Partition.class);
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 32ddfc6418d..2e9bc068978 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -299,7 +299,8 @@ class PartitionLockTest extends Logging {
         val log = super.createLog(isNew, isFutureReplica, offsetCheckpoints, None, None)
         val logDirFailureChannel = new LogDirFailureChannel(1)
         val segments = new LogSegments(log.topicPartition)
-        val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "")
+        val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+          log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "", None, mockTime.scheduler)
         val maxTransactionTimeout = 5 * 60 * 1000
         val producerStateManagerConfig = new ProducerStateManagerConfig(TransactionLogConfigs.PRODUCER_ID_EXPIRATION_MS_DEFAULT, false)
         val producerStateManager = new ProducerStateManager(
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 2134dcfaaa0..6dc6cc2a3c1 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -434,7 +434,8 @@ class PartitionTest extends AbstractPartitionTest {
         val log = super.createLog(isNew, isFutureReplica, offsetCheckpoints, None, None)
         val logDirFailureChannel = new LogDirFailureChannel(1)
         val segments = new LogSegments(log.topicPartition)
-        val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "")
+        val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+          log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "", None, time.scheduler)
         val maxTransactionTimeoutMs = 5 * 60 * 1000
         val producerStateManagerConfig = new ProducerStateManagerConfig(TransactionLogConfigs.PRODUCER_ID_EXPIRATION_MS_DEFAULT, true)
         val producerStateManager = new ProducerStateManager(
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
index f17c724066f..bdbcac462b8 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
@@ -108,7 +108,8 @@ class LogCleanerManagerTest extends Logging {
     val maxTransactionTimeoutMs = 5 * 60 * 1000
     val producerIdExpirationCheckIntervalMs = TransactionLogConfigs.PRODUCER_ID_EXPIRATION_CHECK_INTERVAL_MS_DEFAULT
     val segments = new LogSegments(tp)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(tpDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      tpDir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, time.scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, tpDir, maxTransactionTimeoutMs, producerStateManagerConfig, time)
     val offsets = new LogLoader(
       tpDir,
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index b61eb28530c..ee62b3f0f21 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -92,7 +92,7 @@ class LogCleanerTest extends Logging {
       val mockMetricsGroup = mockMetricsGroupCtor.constructed.get(0)
       val numMetricsRegistered = LogCleaner.MetricNames.size
       verify(mockMetricsGroup, times(numMetricsRegistered)).newGauge(anyString(), any())
-      
+
       // verify that each metric in `LogCleaner` is removed
       LogCleaner.MetricNames.foreach(verify(mockMetricsGroup).removeMetric(_))
 
@@ -188,7 +188,8 @@ class LogCleanerTest extends Logging {
     val maxTransactionTimeoutMs = 5 * 60 * 1000
     val producerIdExpirationCheckIntervalMs = TransactionLogConfigs.PRODUCER_ID_EXPIRATION_CHECK_INTERVAL_MS_DEFAULT
     val logSegments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(dir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      dir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, time.scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, dir,
       maxTransactionTimeoutMs, producerStateManagerConfig, time)
     val offsets = new LogLoader(
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 1a781a93ea6..7838da54173 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -154,7 +154,8 @@ class LogLoaderTest {
           val logStartOffset = logStartOffsets.getOrElse(topicPartition, 0L)
           val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
           val segments = new LogSegments(topicPartition)
-          val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+          val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+            logDir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, time.scheduler)
           val producerStateManager = new ProducerStateManager(topicPartition, logDir,
             this.maxTransactionTimeoutMs, this.producerStateManagerConfig, time)
           val logLoader = new LogLoader(logDir, topicPartition, config, time.scheduler, time,
@@ -367,7 +368,8 @@ class LogLoaderTest {
           super.add(wrapper)
         }
       }
-      val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "")
+      val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+        logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "", None, mockTime.scheduler)
       val producerStateManager = new ProducerStateManager(topicPartition, logDir,
         maxTransactionTimeoutMs, producerStateManagerConfig, mockTime)
       val logLoader = new LogLoader(
@@ -431,7 +433,8 @@ class LogLoaderTest {
     val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
     val config = new LogConfig(new Properties())
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -540,7 +543,8 @@ class LogLoaderTest {
     val config = new LogConfig(logProps)
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -594,7 +598,8 @@ class LogLoaderTest {
     val config = new LogConfig(logProps)
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -647,7 +652,8 @@ class LogLoaderTest {
     val config = new LogConfig(logProps)
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, config.recordVersion, "", None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -1387,7 +1393,7 @@ class LogLoaderTest {
     assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries)
 
     // deliberately remove some of the epoch entries
-    leaderEpochCache.truncateFromEnd(2)
+    leaderEpochCache.truncateFromEndAsyncFlush(2)
     assertNotEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries)
     log.close()
 
@@ -1789,7 +1795,8 @@ class LogLoaderTest {
     log.logSegments.forEach(segment => segments.add(segment))
     assertEquals(5, segments.firstSegment.get.baseOffset)
 
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "", None, mockTime.scheduler)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index b559c192790..e2272941ab3 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -24,7 +24,8 @@ import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{MockTime, Time, Utils}
 import org.apache.kafka.coordinator.transaction.TransactionLogConfigs
-import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint
+import org.apache.kafka.server.util.MockScheduler
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.apache.kafka.storage.internals.log._
 import org.junit.jupiter.api.Assertions._
@@ -33,7 +34,6 @@ import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
 
 import java.io.{File, RandomAccessFile}
-import java.util
 import java.util.{Optional, OptionalLong}
 import scala.collection._
 import scala.jdk.CollectionConverters._
@@ -431,17 +431,9 @@ class LogSegmentTest {
   def testRecoveryRebuildsEpochCache(): Unit = {
     val seg = createSegment(0)
 
-    val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint {
-      private var epochs = Seq.empty[EpochEntry]
+    val checkpoint: LeaderEpochCheckpointFile = new LeaderEpochCheckpointFile(TestUtils.tempFile(), new LogDirFailureChannel(1))
 
-      override def write(epochs: util.Collection[EpochEntry], ignored: Boolean): Unit = {
-        this.epochs = epochs.asScala.toSeq
-      }
-
-      override def read(): java.util.List[EpochEntry] = this.epochs.asJava
-    }
-
-    val cache = new LeaderEpochFileCache(topicPartition, checkpoint)
+    val cache = new LeaderEpochFileCache(topicPartition, checkpoint, new MockScheduler(new MockTime()))
     seg.append(105L, RecordBatch.NO_TIMESTAMP, 104L, MemoryRecords.withRecords(104L, Compression.NONE, 0,
         new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
 
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 65cbcc7bd70..bdb4e6bf7e0 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -2911,7 +2911,8 @@ class ReplicaManagerTest {
     val maxTransactionTimeoutMs = 30000
     val maxProducerIdExpirationMs = 30000
     val segments = new LogSegments(tp)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, tp, mockLogDirFailureChannel, logConfig.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, tp, mockLogDirFailureChannel, logConfig.recordVersion, "", None, time.scheduler)
     val producerStateManager = new ProducerStateManager(tp, logDir,
       maxTransactionTimeoutMs, new ProducerStateManagerConfig(maxProducerIdExpirationMs, true), time)
     val offsets = new LogLoader(
@@ -6523,7 +6524,7 @@ class ReplicaManagerTest {
       partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
 
       val leaderAndIsr = LeaderAndIsr(0, 1, List(0, 1), LeaderRecoveryState.RECOVERED, LeaderAndIsr.InitialPartitionEpoch)
-      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), leaderAndIsr)  
+      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), leaderAndIsr)
 
       replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
       verifyRLMOnLeadershipChange(Collections.singleton(partition), Collections.emptySet())
diff --git a/core/src/test/scala/unit/kafka/server/checkpoints/InMemoryLeaderEpochCheckpointTest.scala b/core/src/test/scala/unit/kafka/server/checkpoints/InMemoryLeaderEpochCheckpointTest.scala
deleted file mode 100644
index 3af126f5c55..00000000000
--- a/core/src/test/scala/unit/kafka/server/checkpoints/InMemoryLeaderEpochCheckpointTest.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.server.checkpoints
-
-import org.apache.kafka.storage.internals.checkpoint.InMemoryLeaderEpochCheckpoint
-import org.apache.kafka.storage.internals.log.EpochEntry
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.Test
-
-import java.io.{BufferedReader, ByteArrayInputStream, InputStreamReader}
-import java.nio.charset.StandardCharsets
-
-class InMemoryLeaderEpochCheckpointTest {
-
-  @Test
-  def shouldAppendNewEntry(): Unit = {
-    val checkpoint = new InMemoryLeaderEpochCheckpoint()
-    val epochs = java.util.Arrays.asList(new EpochEntry(0, 1L), new EpochEntry(1, 2L), new EpochEntry(2, 3L))
-    checkpoint.write(epochs)
-    assertEquals(epochs, checkpoint.read())
-
-    val epochs2 = java.util.Arrays.asList(new EpochEntry(3, 4L), new EpochEntry(4, 5L))
-    checkpoint.write(epochs2)
-
-    assertEquals(epochs2, checkpoint.read())
-  }
-
-  @Test
-  def testReadAsByteBuffer(): Unit = {
-    val checkpoint = new InMemoryLeaderEpochCheckpoint()
-    val expectedEpoch = 0
-    val expectedStartOffset = 1L
-    val expectedVersion = 0
-    val epochs = java.util.Arrays.asList(new EpochEntry(expectedEpoch, expectedStartOffset))
-    checkpoint.write(epochs)
-    assertEquals(epochs, checkpoint.read())
-    val buffer = checkpoint.readAsByteBuffer()
-
-    val bufferedReader = new BufferedReader(new InputStreamReader(new ByteArrayInputStream(buffer.array()), StandardCharsets.UTF_8))
-    assertEquals(expectedVersion.toString, bufferedReader.readLine())
-    assertEquals(epochs.size().toString, bufferedReader.readLine())
-    assertEquals(s"$expectedEpoch $expectedStartOffset", bufferedReader.readLine())
-  }
-}
diff --git a/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala b/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
index a7e370d7f40..7808cedb075 100644
--- a/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala
@@ -19,12 +19,13 @@ package kafka.server.checkpoints
 import kafka.utils.{Logging, TestUtils}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.KafkaStorageException
-import org.apache.kafka.storage.internals.checkpoint.CheckpointFileWithFailureHandler
-import org.apache.kafka.storage.internals.log.LogDirFailureChannel
+import org.apache.kafka.storage.internals.checkpoint.{CheckpointFileWithFailureHandler, LeaderEpochCheckpointFile}
+import org.apache.kafka.storage.internals.log.{EpochEntry, LogDirFailureChannel}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 import org.mockito.Mockito
 
+import java.io.File
 import java.util.Collections
 import scala.collection.Map
 
@@ -97,7 +98,7 @@ class OffsetCheckpointFileWithFailureHandlerTest extends Logging {
     val logDirFailureChannel = new LogDirFailureChannel(10)
     val checkpointFile = new CheckpointFileWithFailureHandler(file, OffsetCheckpointFile.CurrentVersion + 1,
       OffsetCheckpointFile.Formatter, logDirFailureChannel, file.getParent)
-    checkpointFile.write(Collections.singletonList(new TopicPartition("foo", 5) -> 10L), true)
+    checkpointFile.write(Collections.singletonList(new TopicPartition("foo", 5) -> 10L))
     assertThrows(classOf[KafkaStorageException], () => new OffsetCheckpointFile(checkpointFile.file, logDirFailureChannel).read())
   }
 
@@ -133,4 +134,15 @@ class OffsetCheckpointFileWithFailureHandlerTest extends Logging {
     assertThrows(classOf[IllegalArgumentException], () => lazyCheckpoints.fetch("/invalid/kafka-logs", new TopicPartition("foo", 0)))
   }
 
+  @Test
+  def testWriteIfDirExistsShouldNotThrowWhenDirNotExists(): Unit = {
+    val dir = TestUtils.tempDir()
+    val file = dir.toPath.resolve("test-checkpoint").toFile
+    val logDirFailureChannel = new LogDirFailureChannel(10)
+    val checkpointFile = new CheckpointFileWithFailureHandler(file, 0,
+      LeaderEpochCheckpointFile.FORMATTER, logDirFailureChannel, file.getParent)
+
+    dir.renameTo(new File(dir.getAbsolutePath + "-renamed"))
+    checkpointFile.writeIfDirExists(Collections.singletonList(new EpochEntry(1, 42)))
+  }
 }
diff --git a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala
index 05041f39709..6f6d0bdbda5 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala
@@ -20,15 +20,15 @@ package kafka.server.epoch
 import kafka.utils.TestUtils
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET}
-import org.apache.kafka.storage.internals.checkpoint.{LeaderEpochCheckpoint, LeaderEpochCheckpointFile}
+import org.apache.kafka.server.util.MockTime
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
 import org.apache.kafka.storage.internals.log.{EpochEntry, LogDirFailureChannel}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 
 import java.io.File
-import java.util.{Collections, OptionalInt, Optional}
-import scala.collection.Seq
+import java.util.{Collections, Optional, OptionalInt}
 import scala.jdk.CollectionConverters._
 
 /**
@@ -36,13 +36,10 @@ import scala.jdk.CollectionConverters._
   */
 class LeaderEpochFileCacheTest {
   val tp = new TopicPartition("TestTopic", 5)
-  private val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint {
-    private var epochs: Seq[EpochEntry] = Seq()
-    override def write(epochs: java.util.Collection[EpochEntry], ignored: Boolean): Unit = this.epochs = epochs.asScala.toSeq
-    override def read(): java.util.List[EpochEntry] = this.epochs.asJava
-  }
+  val mockTime = new MockTime()
+  private val checkpoint: LeaderEpochCheckpointFile = new LeaderEpochCheckpointFile(TestUtils.tempFile(), new LogDirFailureChannel(1))
 
-  private val cache = new LeaderEpochFileCache(tp, checkpoint)
+  private val cache = new LeaderEpochFileCache(tp, checkpoint, mockTime.scheduler)
 
   @Test
   def testPreviousEpoch(): Unit = {
@@ -57,7 +54,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(10, 20)
     assertEquals(OptionalInt.of(4), cache.previousEpoch)
 
-    cache.truncateFromEnd(18)
+    cache.truncateFromEndAsyncFlush(18)
     assertEquals(OptionalInt.of(2), cache.previousEpoch)
   }
 
@@ -245,12 +242,12 @@ class LeaderEpochFileCacheTest {
     val checkpoint = new LeaderEpochCheckpointFile(new File(checkpointPath), new LogDirFailureChannel(1))
 
     //Given
-    val cache = new LeaderEpochFileCache(tp, checkpoint)
+    val cache = new LeaderEpochFileCache(tp, checkpoint, new MockTime().scheduler)
     cache.assign(2, 6)
 
     //When
     val checkpoint2 = new LeaderEpochCheckpointFile(new File(checkpointPath), new LogDirFailureChannel(1))
-    val cache2 = new LeaderEpochFileCache(tp, checkpoint2)
+    val cache2 = new LeaderEpochFileCache(tp, checkpoint2, new MockTime().scheduler)
 
     //Then
     assertEquals(1, cache2.epochEntries.size)
@@ -387,7 +384,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When clear latest on epoch boundary
-    cache.truncateFromEnd(8)
+    cache.truncateFromEndAsyncFlush(8)
 
     //Then should remove two latest epochs (remove is inclusive)
     assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6)), cache.epochEntries)
@@ -401,7 +398,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset ON epoch boundary
-    cache.truncateFromStart(8)
+    cache.truncateFromStartAsyncFlush(8)
 
     //Then should preserve (3, 8)
     assertEquals(java.util.Arrays.asList(new EpochEntry(3, 8), new EpochEntry(4, 11)), cache.epochEntries)
@@ -415,7 +412,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset BETWEEN epoch boundaries
-    cache.truncateFromStart(9)
+    cache.truncateFromStartAsyncFlush(9)
 
     //Then we should retain epoch 3, but update its offset to 9 as 8 has been removed
     assertEquals(java.util.Arrays.asList(new EpochEntry(3, 9), new EpochEntry(4, 11)), cache.epochEntries)
@@ -429,7 +426,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset before first epoch offset
-    cache.truncateFromStart(1)
+    cache.truncateFromStartAsyncFlush(1)
 
     //Then nothing should change
     assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6), new EpochEntry(3, 8), new EpochEntry(4, 11)), cache.epochEntries)
@@ -443,7 +440,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset on earliest epoch boundary
-    cache.truncateFromStart(6)
+    cache.truncateFromStartAsyncFlush(6)
 
     //Then nothing should change
     assertEquals(java.util.Arrays.asList(new EpochEntry(2, 6), new EpochEntry(3, 8), new EpochEntry(4, 11)), cache.epochEntries)
@@ -457,7 +454,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When
-    cache.truncateFromStart(11)
+    cache.truncateFromStartAsyncFlush(11)
 
     //Then retain the last
     assertEquals(Collections.singletonList(new EpochEntry(4, 11)), cache.epochEntries)
@@ -471,7 +468,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When we clear from a position between offset 8 & offset 11
-    cache.truncateFromStart(9)
+    cache.truncateFromStartAsyncFlush(9)
 
     //Then we should update the middle epoch entry's offset
     assertEquals(java.util.Arrays.asList(new EpochEntry(3, 9), new EpochEntry(4, 11)), cache.epochEntries)
@@ -485,7 +482,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(2, 10)
 
     //When we clear from a position between offset 0 & offset 7
-    cache.truncateFromStart(5)
+    cache.truncateFromStartAsyncFlush(5)
 
     //Then we should keep epoch 0 but update the offset appropriately
     assertEquals(java.util.Arrays.asList(new EpochEntry(0, 5), new EpochEntry(1, 7), new EpochEntry(2, 10)),
@@ -500,7 +497,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset beyond last epoch
-    cache.truncateFromStart(15)
+    cache.truncateFromStartAsyncFlush(15)
 
     //Then update the last
     assertEquals(Collections.singletonList(new EpochEntry(4, 15)), cache.epochEntries)
@@ -514,7 +511,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset BETWEEN epoch boundaries
-    cache.truncateFromEnd( 9)
+    cache.truncateFromEndAsyncFlush( 9)
 
     //Then should keep the preceding epochs
     assertEquals(OptionalInt.of(3), cache.latestEpoch)
@@ -543,7 +540,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset on epoch boundary
-    cache.truncateFromStart(UNDEFINED_EPOCH_OFFSET)
+    cache.truncateFromStartAsyncFlush(UNDEFINED_EPOCH_OFFSET)
 
     //Then should do nothing
     assertEquals(3, cache.epochEntries.size)
@@ -557,7 +554,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(4, 11)
 
     //When reset to offset on epoch boundary
-    cache.truncateFromEnd(UNDEFINED_EPOCH_OFFSET)
+    cache.truncateFromEndAsyncFlush(UNDEFINED_EPOCH_OFFSET)
 
     //Then should do nothing
     assertEquals(3, cache.epochEntries.size)
@@ -578,13 +575,13 @@ class LeaderEpochFileCacheTest {
   @Test
   def shouldClearEarliestOnEmptyCache(): Unit = {
     //Then
-    cache.truncateFromStart(7)
+    cache.truncateFromStartAsyncFlush(7)
   }
 
   @Test
   def shouldClearLatestOnEmptyCache(): Unit = {
     //Then
-    cache.truncateFromEnd(7)
+    cache.truncateFromEndAsyncFlush(7)
   }
 
   @Test
@@ -600,7 +597,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(10, 20)
     assertEquals(OptionalInt.of(4), cache.previousEpoch(10))
 
-    cache.truncateFromEnd(18)
+    cache.truncateFromEndAsyncFlush(18)
     assertEquals(OptionalInt.of(2), cache.previousEpoch(cache.latestEpoch.getAsInt))
   }
 
@@ -617,7 +614,7 @@ class LeaderEpochFileCacheTest {
     cache.assign(10, 20)
     assertEquals(Optional.of(new EpochEntry(4, 15)), cache.previousEntry(10))
 
-    cache.truncateFromEnd(18)
+    cache.truncateFromEndAsyncFlush(18)
     assertEquals(Optional.of(new EpochEntry(2, 10)), cache.previousEntry(cache.latestEpoch.getAsInt))
   }
 
@@ -658,4 +655,15 @@ class LeaderEpochFileCacheTest {
     assertEquals(OptionalInt.empty(), cache.epochForOffset(5))
   }
 
+  @Test
+  def shouldWriteCheckpointOnTruncation(): Unit = {
+    cache.assign(2, 6)
+    cache.assign(3, 8)
+    cache.assign(4, 11)
+
+    cache.truncateFromEndAsyncFlush(11)
+    cache.truncateFromStartAsyncFlush(8)
+
+    assertEquals(List(new EpochEntry(3, 8)).asJava, checkpoint.read())
+  }
 }
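
A note on testability: shouldWriteCheckpointOnTruncation above asserts the checkpoint contents immediately after the truncate calls, which only works if the injected test scheduler runs the asynchronous flush deterministically. As a minimal, hypothetical sketch of that property (InlineScheduler is an illustrative name, not Kafka's MockTime/MockScheduler), a scheduler that executes its tasks inline gives a test the same guarantee:

    // InlineScheduler: a hypothetical deterministic scheduler for tests.
    // Running the "background" flush on the caller's thread means the
    // checkpoint file is readable as soon as the truncate call returns.
    import java.util.concurrent.Executor;

    final class InlineScheduler implements Executor {
        @Override
        public void execute(Runnable flushTask) {
            flushTask.run(); // execute synchronously instead of deferring
        }
    }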
diff --git a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
index d25fdb1b4e9..6280318af5d 100644
--- a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
+++ b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
@@ -139,7 +139,8 @@ class SchedulerTest {
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val logDirFailureChannel = new LogDirFailureChannel(10)
     val segments = new LogSegments(topicPartition)
-    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "")
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(
+      logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "", None, mockTime.scheduler)
     val producerStateManager = new ProducerStateManager(topicPartition, logDir,
       maxTransactionTimeoutMs, new ProducerStateManagerConfig(maxProducerIdExpirationMs, false), mockTime)
     val offsets = new LogLoader(
diff --git a/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java b/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
index 9c115881328..6efbaa136e0 100644
--- a/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
+++ b/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java
@@ -72,7 +72,7 @@ public class CheckpointFile<T> {
         tempPath = Paths.get(absolutePath + ".tmp");
     }
 
-    public void write(Collection<T> entries, boolean sync) throws IOException {
+    public void write(Collection<T> entries) throws IOException {
         synchronized (lock) {
             // write to temp file and then swap with the existing file
             try (FileOutputStream fileOutputStream = new FileOutputStream(tempPath.toFile());
@@ -80,12 +80,10 @@ public class CheckpointFile<T> {
                 CheckpointWriteBuffer<T> checkpointWriteBuffer = new CheckpointWriteBuffer<>(writer, version, formatter);
                 checkpointWriteBuffer.write(entries);
                 writer.flush();
-                if (sync) {
-                    fileOutputStream.getFD().sync();
-                }
+                fileOutputStream.getFD().sync();
             }
 
-            Utils.atomicMoveWithFallback(tempPath, absolutePath, sync);
+            Utils.atomicMoveWithFallback(tempPath, absolutePath);
         }
     }
 
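For context, the hunk above restores the standard durable-checkpoint sequence: write the entries to a temporary file, flush, fsync, then atomically rename over the old file so readers only ever see a complete checkpoint. A self-contained sketch of the same pattern follows; SimpleCheckpoint and writeLines are illustrative names rather than Kafka's API, and Kafka's Utils.atomicMoveWithFallback additionally copes with filesystems that lack atomic moves:

    import java.io.BufferedWriter;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.OutputStreamWriter;
    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.StandardCopyOption;
    import java.util.Collection;

    final class SimpleCheckpoint {
        private final Path target;
        private final Path tmp;

        SimpleCheckpoint(Path target) {
            this.target = target;
            this.tmp = target.resolveSibling(target.getFileName() + ".tmp");
        }

        synchronized void writeLines(Collection<String> lines) throws IOException {
            try (FileOutputStream out = new FileOutputStream(tmp.toFile());
                 BufferedWriter writer = new BufferedWriter(
                         new OutputStreamWriter(out, StandardCharsets.UTF_8))) {
                for (String line : lines) {
                    writer.write(line);
                    writer.newLine();
                }
                writer.flush();
                // fsync before the rename: without it, an ungraceful OS shutdown
                // can leave a truncated or empty file behind the swapped name.
                out.getFD().sync();
            }
            // The atomic rename publishes the new content only once it is durable.
            Files.move(tmp, target, StandardCopyOption.ATOMIC_MOVE);
        }
    }
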
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java
index 35abfb5a984..79963d79d23 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/CheckpointFileWithFailureHandler.java
@@ -19,13 +19,19 @@ package org.apache.kafka.storage.internals.checkpoint;
 import org.apache.kafka.common.errors.KafkaStorageException;
 import org.apache.kafka.server.common.CheckpointFile;
 import org.apache.kafka.storage.internals.log.LogDirFailureChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.File;
+import java.io.FileNotFoundException;
 import java.io.IOException;
+import java.nio.file.NoSuchFileException;
 import java.util.Collection;
 import java.util.List;
 
 public class CheckpointFileWithFailureHandler<T> {
+    private static final Logger log = LoggerFactory.getLogger(CheckpointFileWithFailureHandler.class);
+
 
     public final File file;
     private final LogDirFailureChannel logDirFailureChannel;
@@ -41,9 +47,21 @@ public class CheckpointFileWithFailureHandler<T> {
         checkpointFile = new CheckpointFile<>(file, version, formatter);
     }
 
-    public void write(Collection<T> entries, boolean sync) {
+    public void write(Collection<T> entries) {
+        try {
+            checkpointFile.write(entries);
+        } catch (IOException e) {
+            String msg = "Error while writing to checkpoint file " + file.getAbsolutePath();
+            logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e);
+            throw new KafkaStorageException(msg, e);
+        }
+    }
+
+    public void writeIfDirExists(Collection<T> entries) {
         try {
-            checkpointFile.write(entries, sync);
+            checkpointFile.write(entries);
+        } catch (FileNotFoundException | NoSuchFileException e) {
+            log.warn("Failed to write to checkpoint file {}. This is ok if the topic/partition is being deleted", file.getAbsolutePath(), e);
         } catch (IOException e) {
             String msg = "Error while writing to checkpoint file " + file.getAbsolutePath();
             logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e);
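
The new writeIfDirExists variant exists because an asynchronous flush can race with topic deletion, which renames the partition's log directory out from under the writer. A hedged sketch of the same tolerate-missing-directory pattern (IoAction and TolerantCheckpointWriter are illustrative names, not Kafka classes):

    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.nio.file.NoSuchFileException;

    final class TolerantCheckpointWriter {
        @FunctionalInterface
        interface IoAction {
            void run() throws IOException;
        }

        // Returns true when the write happened, false when it was skipped
        // because the backing directory no longer exists; any other I/O
        // failure still propagates so the log dir can be marked offline.
        static boolean writeIfDirExists(IoAction write) throws IOException {
            try {
                write.run();
                return true;
            } catch (FileNotFoundException | NoSuchFileException e) {
                return false; // expected while the topic/partition is deleted
            }
        }
    }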
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java
deleted file mode 100644
index 499c19fb78b..00000000000
--- a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/InMemoryLeaderEpochCheckpoint.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.storage.internals.checkpoint;
-
-import org.apache.kafka.server.common.CheckpointFile;
-import org.apache.kafka.storage.internals.log.EpochEntry;
-
-import java.io.BufferedWriter;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-
-/**
- * This class stores a list of EpochEntry(LeaderEpoch + Offsets) to memory
- *
- * The motivation for this class is to allow remote log manager to create the RemoteLogSegmentMetadata(RLSM)
- * with the correct leader epoch info for a specific segment. To do that, we need to rely on the LeaderEpochCheckpointCache
- * to truncate from start and end, to get the epoch info. However, we don't really want to truncate the epochs in cache
- * (and write to checkpoint file in the end). So, we introduce this InMemoryLeaderEpochCheckpoint to feed into LeaderEpochCheckpointCache,
- * and when we truncate the epoch for RLSM, we can do them in memory without affecting the checkpoint file, and without interacting with file system.
- */
-public class InMemoryLeaderEpochCheckpoint implements LeaderEpochCheckpoint {
-    private List<EpochEntry> epochs = Collections.emptyList();
-
-    public void write(Collection<EpochEntry> epochs, boolean ignored) {
-        this.epochs = new ArrayList<>(epochs);
-    }
-
-    public List<EpochEntry> read() {
-        return Collections.unmodifiableList(epochs);
-    }
-
-    public ByteBuffer readAsByteBuffer() throws IOException {
-        ByteArrayOutputStream stream = new ByteArrayOutputStream();
-        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream, StandardCharsets.UTF_8))) {
-            CheckpointFile.CheckpointWriteBuffer<EpochEntry> writeBuffer = new CheckpointFile.CheckpointWriteBuffer<>(writer, 0, LeaderEpochCheckpointFile.FORMATTER);
-            writeBuffer.write(epochs);
-            writer.flush();
-        }
-
-        return ByteBuffer.wrap(stream.toByteArray());
-    }
-}
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpoint.java b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpoint.java
deleted file mode 100644
index 28ffae03df0..00000000000
--- a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpoint.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.storage.internals.checkpoint;
-
-import org.apache.kafka.storage.internals.log.EpochEntry;
-
-import java.util.Collection;
-import java.util.List;
-
-public interface LeaderEpochCheckpoint {
-    // in file-backed checkpoint implementation, the content should be
-    // synced to the device if `sync` is true
-    void write(Collection<EpochEntry> epochs, boolean sync);
-
-    default void write(Collection<EpochEntry> epochs) {
-        write(epochs, true);
-    }
-
-    List<EpochEntry> read();
-}
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java
index 3472182aeea..392a3653340 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/checkpoint/LeaderEpochCheckpointFile.java
@@ -38,7 +38,7 @@ import java.util.regex.Pattern;
  * 1  2
  * -----checkpoint file end----------
  */
-public class LeaderEpochCheckpointFile implements LeaderEpochCheckpoint {
+public class LeaderEpochCheckpointFile {
 
     public static final Formatter FORMATTER = new Formatter();
 
@@ -53,11 +53,14 @@ public class LeaderEpochCheckpointFile implements LeaderEpochCheckpoint {
     }
 
     public void write(Collection<EpochEntry> epochs) {
-        write(epochs, true);
+        checkpoint.write(epochs);
     }
 
-    public void write(Collection<EpochEntry> epochs, boolean sync) {
-        checkpoint.write(epochs, sync);
+    public void writeForTruncation(Collection<EpochEntry> epochs) {
+        // Writing epoch entries after truncation is done asynchronously for performance reasons.
+        // This could cause NoSuchFileException when the directory is renamed concurrently for topic deletion,
+        // so we use writeIfDirExists here.
+        checkpoint.writeIfDirExists(epochs);
     }
 
     public List<EpochEntry> read() {
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java b/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
index 03df6cc0dce..7b78a70993d 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/epoch/LeaderEpochFileCache.java
@@ -18,15 +18,18 @@ package org.apache.kafka.storage.internals.epoch;
 
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.server.util.Scheduler;
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile;
 import org.apache.kafka.storage.internals.log.EpochEntry;
-import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint;
 import org.slf4j.Logger;
 
 import java.util.AbstractMap;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.NavigableMap;
 import java.util.Optional;
 import java.util.OptionalInt;
@@ -42,10 +45,15 @@ import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UND
  * <p>
  * Leader Epoch = epoch assigned to each leader by the controller.
  * Offset = offset of the first message in each epoch.
+ * <p>
+ * Note that {@link #truncateFromStartAsyncFlush} and {@link #truncateFromEndAsyncFlush} flush the epoch-entry changes to the checkpoint asynchronously.
+ * Hence, it is the instantiator's responsibility to restore the cache to the correct state after instantiating
+ * this class from the checkpoint file (which might contain stale epoch entries right after instantiation).
  */
 public class LeaderEpochFileCache {
     private final TopicPartition topicPartition;
-    private final LeaderEpochCheckpoint checkpoint;
+    private final LeaderEpochCheckpointFile checkpoint;
+    private final Scheduler scheduler;
     private final Logger log;
 
     private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
@@ -55,16 +63,40 @@ public class LeaderEpochFileCache {
     /**
      * @param topicPartition the associated topic partition
      * @param checkpoint     the checkpoint file
+     * @param scheduler      the scheduler to use for async I/O operations
      */
     @SuppressWarnings("this-escape")
-    public LeaderEpochFileCache(TopicPartition topicPartition, LeaderEpochCheckpoint checkpoint) {
+    public LeaderEpochFileCache(TopicPartition topicPartition, LeaderEpochCheckpointFile checkpoint, Scheduler scheduler) {
         this.checkpoint = checkpoint;
         this.topicPartition = topicPartition;
+        this.scheduler = scheduler;
         LogContext logContext = new LogContext("[LeaderEpochCache " + topicPartition + "] ");
         log = logContext.logger(LeaderEpochFileCache.class);
         checkpoint.read().forEach(this::assign);
     }
 
+    /**
+     * Instantiate a new LeaderEpochFileCache with the provided epoch entries instead of loading them from the backing checkpoint file.
+     * The provided epoch entries are expected to be at least as fresh as those in the checkpoint file.
+     * @param epochEntries the current epoch entries
+     * @param topicPartition the associated topic partition
+     * @param checkpoint the checkpoint file
+     * @param scheduler the scheduler to use for async I/O operations
+     */
+    private LeaderEpochFileCache(List<EpochEntry> epochEntries,
+                                 TopicPartition topicPartition,
+                                 LeaderEpochCheckpointFile checkpoint,
+                                 Scheduler scheduler) {
+        this.checkpoint = checkpoint;
+        this.topicPartition = topicPartition;
+        this.scheduler = scheduler;
+        LogContext logContext = new LogContext("[LeaderEpochCache " + topicPartition + "] ");
+        log = logContext.logger(LeaderEpochFileCache.class);
+        for (EpochEntry entry : epochEntries) {
+            epochs.put(entry.epoch, entry);
+        }
+    }
+
     /**
      * Assigns the supplied Leader Epoch to the supplied Offset
      * Once the epoch is assigned it cannot be reassigned
@@ -73,7 +105,7 @@ public class LeaderEpochFileCache {
         EpochEntry entry = new EpochEntry(epoch, startOffset);
         if (assign(entry)) {
             log.debug("Appended new epoch entry {}. Cache now contains {} 
entries.", entry, epochs.size());
-            writeToFile(true);
+            writeToFile();
         }
     }
 
@@ -83,7 +115,7 @@ public class LeaderEpochFileCache {
                 log.debug("Appended new epoch entry {}. Cache now contains {} 
entries.", entry, epochs.size());
             }
         });
-        if (!entries.isEmpty()) writeToFile(true);
+        if (!entries.isEmpty()) writeToFile();
     }
 
     private boolean isUpdateNeeded(EpochEntry entry) {
@@ -117,7 +149,9 @@ public class LeaderEpochFileCache {
      * Remove any entries which violate monotonicity prior to appending a new entry
      */
     private void maybeTruncateNonMonotonicEntries(EpochEntry newEntry) {
-        List<EpochEntry> removedEpochs = removeFromEnd(entry -> entry.epoch >= newEntry.epoch || entry.startOffset >= newEntry.startOffset);
+        List<EpochEntry> removedEpochs = removeWhileMatching(
+                epochs.descendingMap().entrySet().iterator(),
+                entry -> entry.epoch >= newEntry.epoch || entry.startOffset >= newEntry.startOffset);
 
         if (removedEpochs.size() > 1 || (!removedEpochs.isEmpty() && removedEpochs.get(0).startOffset != newEntry.startOffset)) {
 
@@ -128,15 +162,7 @@ public class LeaderEpochFileCache {
         }
     }
 
-    private List<EpochEntry> removeFromEnd(Predicate<EpochEntry> predicate) {
-        return removeWhileMatching(epochs.descendingMap().entrySet().iterator(), predicate);
-    }
-
-    private List<EpochEntry> removeFromStart(Predicate<EpochEntry> predicate) {
-        return removeWhileMatching(epochs.entrySet().iterator(), predicate);
-    }
-
-    private List<EpochEntry> removeWhileMatching(Iterator<Map.Entry<Integer, EpochEntry>> iterator, Predicate<EpochEntry> predicate) {
+    private static List<EpochEntry> removeWhileMatching(Iterator<Map.Entry<Integer, EpochEntry>> iterator, Predicate<EpochEntry> predicate) {
         ArrayList<EpochEntry> removedEpochs = new ArrayList<>();
 
         while (iterator.hasNext()) {
@@ -305,22 +331,23 @@ public class LeaderEpochFileCache {
 
     /**
      * Removes all epoch entries from the store with start offsets greater than or equal to the passed offset.
+     * <p>
+     * Checkpoint-flushing is done asynchronously.
      */
-    public void truncateFromEnd(long endOffset) {
+    public void truncateFromEndAsyncFlush(long endOffset) {
         lock.writeLock().lock();
         try {
-            Optional<EpochEntry> epochEntry = latestEntry();
-            if (endOffset >= 0 && epochEntry.isPresent() && epochEntry.get().startOffset >= endOffset) {
-                List<EpochEntry> removedEntries = removeFromEnd(x -> x.startOffset >= endOffset);
-
-                // We intentionally don't force flushing change to the device here because:
+            List<EpochEntry> removedEntries = truncateFromEnd(epochs, endOffset);
+            if (!removedEntries.isEmpty()) {
+                // We flush the change to the device in the background because:
                 // - To avoid fsync latency
                 //   * fsync latency could be huge on a disk glitch, which is not rare in spinning drives
                 //   * This method is called by ReplicaFetcher threads, which could block replica fetching
                 //     then causing ISR shrink or high produce response time degradation in remote scope on high fsync latency.
-                // - Even when stale epochs remained in LeaderEpoch file due to the unclean shutdown, it will be handled by
-                //   another truncateFromEnd call on log loading procedure so it won't be a problem
-                writeToFile(false);
+                // - We still flush the change in #assign synchronously, meaning that the checkpoint file is guaranteed to never miss an entry.
+                //   * Even when stale epochs are restored from the checkpoint file after an unclean shutdown, they will be handled by
+                //     another truncateFromEnd call during log loading, so it won't be a problem
+                scheduler.scheduleOnce("leader-epoch-cache-flush-" + topicPartition, this::writeToFileForTruncation);
 
                 log.debug("Cleared entries {} from epoch cache after truncating to end offset {}, leaving {} entries in the cache.", removedEntries, endOffset, epochs.size());
             }
@@ -334,28 +361,27 @@ public class LeaderEpochFileCache {
      * be offset, then clears any previous epoch entries.
      * <p>
      * This method is exclusive: so truncateFromStart(6) will retain an entry at offset 6.
+     * <p>
+     * Checkpoint-flushing is done asynchronously.
      *
      * @param startOffset the offset to clear up to
      */
-    public void truncateFromStart(long startOffset) {
+    public void truncateFromStartAsyncFlush(long startOffset) {
         lock.writeLock().lock();
         try {
-            List<EpochEntry> removedEntries = removeFromStart(entry -> entry.startOffset <= startOffset);
-
+            List<EpochEntry> removedEntries = truncateFromStart(epochs, startOffset);
             if (!removedEntries.isEmpty()) {
-                EpochEntry firstBeforeStartOffset = removedEntries.get(removedEntries.size() - 1);
-                EpochEntry updatedFirstEntry = new EpochEntry(firstBeforeStartOffset.epoch, startOffset);
-                epochs.put(updatedFirstEntry.epoch, updatedFirstEntry);
-
-                // We intentionally don't force flushing change to the device here because:
+                // We flush the change to the device in the background because:
                 // - To avoid fsync latency
                 //   * fsync latency could be huge on a disk glitch, which is not rare in spinning drives
                 //   * This method is called as part of deleteRecords while holding UnifiedLog#lock.
                 //      - Meanwhile all produces against the partition will be blocked, which causes req-handlers to be exhausted
-                // - Even when stale epochs remained in LeaderEpoch file due to the unclean shutdown, it will be recovered by
-                //   another truncateFromStart call on log loading procedure so it won't be a problem
-                writeToFile(false);
+                // - We still flush the change in #assign synchronously, meaning that the checkpoint file is guaranteed to never miss an entry.
+                //   * Even when stale epochs are restored from the checkpoint file after an unclean shutdown, they will be handled by
+                //     another truncateFromStart call during log loading, so it won't be a problem
+                scheduler.scheduleOnce("leader-epoch-cache-flush-" + topicPartition, this::writeToFileForTruncation);
 
+                EpochEntry updatedFirstEntry = removedEntries.get(removedEntries.size() - 1);
                 log.debug("Cleared entries {} and rewrote first entry {} after truncating to start offset {}, leaving {} in the cache.", removedEntries, updatedFirstEntry, startOffset, epochs.size());
             }
         } finally {
@@ -363,6 +389,27 @@ public class LeaderEpochFileCache {
         }
     }
 
+    private static List<EpochEntry> truncateFromStart(TreeMap<Integer, EpochEntry> epochs, long startOffset) {
+        List<EpochEntry> removedEntries = removeWhileMatching(
+                epochs.entrySet().iterator(), entry -> entry.startOffset <= startOffset);
+
+        if (!removedEntries.isEmpty()) {
+            EpochEntry firstBeforeStartOffset = removedEntries.get(removedEntries.size() - 1);
+            EpochEntry updatedFirstEntry = new EpochEntry(firstBeforeStartOffset.epoch, startOffset);
+            epochs.put(updatedFirstEntry.epoch, updatedFirstEntry);
+        }
+
+        return removedEntries;
+    }
+
+    private static List<EpochEntry> truncateFromEnd(TreeMap<Integer, EpochEntry> epochs, long endOffset) {
+        Optional<EpochEntry> epochEntry = Optional.ofNullable(epochs.lastEntry()).map(Entry::getValue);
+        if (endOffset >= 0 && epochEntry.isPresent() && epochEntry.get().startOffset >= endOffset) {
+            return removeWhileMatching(epochs.descendingMap().entrySet().iterator(), x -> x.startOffset >= endOffset);
+        }
+        }
+        return Collections.emptyList();
+    }
+
     public OptionalInt epochForOffset(long offset) {
         lock.readLock().lock();
         try {
@@ -386,11 +433,39 @@ public class LeaderEpochFileCache {
         }
     }
 
-    public LeaderEpochFileCache writeTo(LeaderEpochCheckpoint leaderEpochCheckpoint) {
+    /**
+     * Returns a new LeaderEpochFileCache that contains the same epoch entries but is backed by the given checkpoint file.
+     * @param leaderEpochCheckpoint the new checkpoint file
+     * @return a new LeaderEpochFileCache instance
+     */
+    public LeaderEpochFileCache withCheckpoint(LeaderEpochCheckpointFile leaderEpochCheckpoint) {
+        lock.readLock().lock();
+        try {
+            return new LeaderEpochFileCache(epochEntries(),
+                                            topicPartition,
+                                            leaderEpochCheckpoint,
+                                            scheduler);
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    /**
+     * Returns the leader epoch entries within the range of the given start and end offset.
+     * @param startOffset The start offset of the epoch entries (inclusive).
+     * @param endOffset   The end offset of the epoch entries (exclusive).
+     * @return the leader epoch entries
+     */
+    public List<EpochEntry> epochEntriesInRange(long startOffset, long endOffset) {
         lock.readLock().lock();
         try {
-            leaderEpochCheckpoint.write(epochEntries());
-            return new LeaderEpochFileCache(topicPartition, leaderEpochCheckpoint);
+            TreeMap<Integer, EpochEntry> epochsCopy = new TreeMap<>(this.epochs);
+            if (startOffset >= 0) {
+                truncateFromStart(epochsCopy, startOffset);
+            }
+            truncateFromEnd(epochsCopy, endOffset);
+            return new ArrayList<>(epochsCopy.values());
         } finally {
             lock.readLock().unlock();
         }
@@ -403,7 +478,7 @@ public class LeaderEpochFileCache {
         lock.writeLock().lock();
         try {
             epochs.clear();
-            writeToFile(true);
+            writeToFile();
         } finally {
             lock.writeLock().unlock();
         }
@@ -440,12 +515,23 @@ public class LeaderEpochFileCache {
         }
     }
 
-    private void writeToFile(boolean sync) {
+    private void writeToFile() {
+        lock.readLock().lock();
+        try {
+            checkpoint.write(epochs.values());
+        } finally {
+            lock.readLock().unlock();
+        }
+    }
+
+    private void writeToFileForTruncation() {
+        List<EpochEntry> entries;
         lock.readLock().lock();
         try {
-            checkpoint.write(epochs.values(), sync);
+            entries = new ArrayList<>(epochs.values());
         } finally {
             lock.readLock().unlock();
         }
+        checkpoint.writeForTruncation(entries);
     }
 }
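
Taken together, the diff above establishes the following flushing policy: assign() writes the checkpoint synchronously so the file never misses an entry; the truncate paths schedule the flush on a background thread; and the background task snapshots the entries under the read lock before touching the file. epochEntriesInRange applies the same static truncation helpers to a TreeMap copy, which is what allows InMemoryLeaderEpochCheckpoint to be deleted. A condensed, hypothetical sketch of that policy (EpochCacheSketch is illustrative, with a plain Executor standing in for Kafka's Scheduler):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.TreeMap;
    import java.util.concurrent.Executor;
    import java.util.concurrent.locks.ReentrantReadWriteLock;

    final class EpochCacheSketch {
        record Entry(int epoch, long startOffset) { }

        private final TreeMap<Integer, Entry> epochs = new TreeMap<>();
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final Executor scheduler; // stands in for Kafka's Scheduler

        EpochCacheSketch(Executor scheduler) {
            this.scheduler = scheduler;
        }

        void assign(int epoch, long startOffset) {
            lock.writeLock().lock();
            try {
                epochs.put(epoch, new Entry(epoch, startOffset));
                flush(snapshot()); // synchronous: the file never misses an entry
            } finally {
                lock.writeLock().unlock();
            }
        }

        void truncateFromEnd(long endOffset) {
            lock.writeLock().lock();
            try {
                epochs.values().removeIf(e -> e.startOffset() >= endOffset);
                // Asynchronous: the caller never pays fsync latency; stale
                // entries left by a crash are re-truncated during log loading.
                scheduler.execute(() -> flush(snapshot()));
            } finally {
                lock.writeLock().unlock();
            }
        }

        private List<Entry> snapshot() {
            lock.readLock().lock(); // safe: write-lock holders may take the read lock
            try {
                return new ArrayList<>(epochs.values());
            } finally {
                lock.readLock().unlock();
            }
        }

        private void flush(List<Entry> entries) {
            // Durable write omitted; see the CheckpointFile sketch earlier.
        }
    }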

