This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.0 by this push:
     new cd233e9  KAFKA-13099; Transactional expiration should account for max 
batch size (#11098)
cd233e9 is described below

commit cd233e92b0a17797f3d46b119cb8cf18d049be62
Author: Jason Gustafson <[email protected]>
AuthorDate: Tue Jul 27 18:23:00 2021 -0700

    KAFKA-13099; Transactional expiration should account for max batch size 
(#11098)
    
    When expiring transactionalIds, we group the tombstones together into 
batches. Currently there is no limit on the size of these batches, which can 
lead to `MESSAGE_TOO_LARGE` errors when a bunch of transactionalIds need to be 
expired at the same time. This patch fixes the problem by ensuring that the 
batch size respects the configured limit. Any transactionalIds which are 
eligible for expiration and cannot be fit into the batch are postponed until 
the next periodic check.
    
    Reviewers: David Jacot <[email protected]>, Guozhang Wang 
<[email protected]>
---
 .../apache/kafka/common/record/MemoryRecords.java  |  14 ++
 .../transaction/TransactionMetadata.scala          |   5 +
 .../transaction/TransactionStateManager.scala      | 208 +++++++++++-----
 core/src/main/scala/kafka/utils/Pool.scala         |  11 +-
 .../AbstractCoordinatorConcurrencyTest.scala       |   7 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |  10 +-
 .../transaction/TransactionStateManagerTest.scala  | 262 ++++++++++++++++++---
 7 files changed, 418 insertions(+), 99 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java 
b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index 1991759..b631171 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -408,6 +408,20 @@ public class MemoryRecords extends AbstractRecords {
         return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
compressionType, timestampType, baseOffset);
     }
 
+    public static MemoryRecordsBuilder builder(ByteBuffer buffer,
+                                               CompressionType compressionType,
+                                               TimestampType timestampType,
+                                               long baseOffset,
+                                               int maxSize) {
+        long logAppendTime = RecordBatch.NO_TIMESTAMP;
+        if (timestampType == TimestampType.LOG_APPEND_TIME)
+            logAppendTime = System.currentTimeMillis();
+
+        return new MemoryRecordsBuilder(buffer, 
RecordBatch.CURRENT_MAGIC_VALUE, compressionType, timestampType, baseOffset,
+            logAppendTime, RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE,
+            false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, maxSize);
+    }
+
     public static MemoryRecordsBuilder idempotentBuilder(ByteBuffer buffer,
                                                          CompressionType 
compressionType,
                                                          long baseOffset,
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index 1027468..0f6d4b7 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -66,6 +66,8 @@ private[transaction] sealed trait TransactionState {
   def name: String
 
   def validPreviousStates: Set[TransactionState]
+
+  def isExpirationAllowed: Boolean = false
 }
 
 /**
@@ -78,6 +80,7 @@ private[transaction] case object Empty extends 
TransactionState {
   val id: Byte = 0
   val name: String = "Empty"
   val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteCommit, 
CompleteAbort)
+  override def isExpirationAllowed: Boolean = true
 }
 
 /**
@@ -125,6 +128,7 @@ private[transaction] case object CompleteCommit extends 
TransactionState {
   val id: Byte = 4
   val name: String = "CompleteCommit"
   val validPreviousStates: Set[TransactionState] = Set(PrepareCommit)
+  override def isExpirationAllowed: Boolean = true
 }
 
 /**
@@ -136,6 +140,7 @@ private[transaction] case object CompleteAbort extends 
TransactionState {
   val id: Byte = 5
   val name: String = "CompleteAbort"
   val validPreviousStates: Set[TransactionState] = Set(PrepareAbort)
+  override def isExpirationAllowed: Boolean = true
 }
 
 /**
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 25580f2..217b383 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -21,6 +21,7 @@ import java.util.Properties
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.ReentrantReadWriteLock
+
 import kafka.log.{AppendOrigin, LogConfig}
 import kafka.message.UncompressedCodec
 import kafka.server.{Defaults, FetchLogEnd, ReplicaManager, RequestLocal}
@@ -32,7 +33,7 @@ import 
org.apache.kafka.common.message.ListTransactionsResponseData
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.metrics.stats.{Avg, Max}
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.record.{FileRecords, MemoryRecords, 
SimpleRecord}
+import org.apache.kafka.common.record.{FileRecords, MemoryRecords, 
MemoryRecordsBuilder, Record, SimpleRecord, TimestampType}
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.{Time, Utils}
@@ -140,79 +141,162 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def enableTransactionalIdExpiration(): Unit = {
-    scheduler.schedule("transactionalId-expiration", () => {
-      val now = time.milliseconds()
-      inReadLock(stateLock) {
-        val transactionalIdByPartition: Map[Int, 
mutable.Iterable[TransactionalIdCoordinatorEpochAndMetadata]] =
-          transactionMetadataCache.flatMap { case (_, entry) =>
-            entry.metadataPerTransactionalId.filter { case (_, txnMetadata) => 
txnMetadata.state match {
-              case Empty | CompleteCommit | CompleteAbort => true
-              case _ => false
-            }
-            }.filter { case (_, txnMetadata) =>
-              txnMetadata.txnLastUpdateTimestamp <= now - 
config.transactionalIdExpirationMs
-            }.map { case (transactionalId, txnMetadata) =>
-              val txnMetadataTransition = txnMetadata.inLock {
-                txnMetadata.prepareDead()
+  private def removeExpiredTransactionalIds(
+    transactionPartition: TopicPartition,
+    txnMetadataCacheEntry: TxnMetadataCacheEntry,
+  ): Unit = {
+    inReadLock(stateLock) {
+      replicaManager.getLogConfig(transactionPartition) match {
+        case Some(logConfig) =>
+          val currentTimeMs = time.milliseconds()
+          val maxBatchSize = logConfig.maxMessageSize
+          val expired = 
mutable.ListBuffer.empty[TransactionalIdCoordinatorEpochAndMetadata]
+          var recordsBuilder: MemoryRecordsBuilder = null
+          val stateEntries = 
txnMetadataCacheEntry.metadataPerTransactionalId.values.iterator.buffered
+
+          def flushRecordsBuilder(): Unit = {
+            writeTombstonesForExpiredTransactionalIds(
+              transactionPartition,
+              expired.toSeq,
+              recordsBuilder.build()
+            )
+            expired.clear()
+            recordsBuilder = null
+          }
+
+          while (stateEntries.hasNext) {
+            val txnMetadata = stateEntries.head
+            val transactionalId = txnMetadata.transactionalId
+            var fullBatch = false
+
+            txnMetadata.inLock {
+              if (txnMetadata.pendingState.isEmpty && 
shouldExpire(txnMetadata, currentTimeMs)) {
+                if (recordsBuilder == null) {
+                  recordsBuilder = MemoryRecords.builder(
+                    ByteBuffer.allocate(math.min(16384, maxBatchSize)),
+                    TransactionLog.EnforcedCompressionType,
+                    TimestampType.CREATE_TIME,
+                    0L,
+                    maxBatchSize
+                  )
+                }
+
+                if (maybeAppendExpiration(txnMetadata, recordsBuilder, 
currentTimeMs)) {
+                  val transitMetadata = txnMetadata.prepareDead()
+                  expired += TransactionalIdCoordinatorEpochAndMetadata(
+                    transactionalId,
+                    txnMetadataCacheEntry.coordinatorEpoch,
+                    transitMetadata
+                  )
+                } else {
+                  fullBatch = true
+                }
               }
-              TransactionalIdCoordinatorEpochAndMetadata(transactionalId, 
entry.coordinatorEpoch, txnMetadataTransition)
             }
-          }.groupBy { transactionalIdCoordinatorEpochAndMetadata =>
-            
partitionFor(transactionalIdCoordinatorEpochAndMetadata.transactionalId)
+
+            if (fullBatch) {
+              flushRecordsBuilder()
+            } else {
+              // Advance the iterator if we do not need to retry the append
+              stateEntries.next()
+            }
           }
 
-        val recordsPerPartition = transactionalIdByPartition
-          .map { case (partition, transactionalIdCoordinatorEpochAndMetadatas) 
=>
-            val deletes: Array[SimpleRecord] = 
transactionalIdCoordinatorEpochAndMetadatas.map { entry =>
-              new SimpleRecord(now, 
TransactionLog.keyToBytes(entry.transactionalId), null)
-            }.toArray
-            val records = 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, deletes: _*)
-            val topicPartition = new 
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partition)
-            (topicPartition, records)
+          if (expired.nonEmpty) {
+            flushRecordsBuilder()
           }
 
-        def removeFromCacheCallback(responses: collection.Map[TopicPartition, 
PartitionResponse]): Unit = {
-          responses.forKeyValue { (topicPartition, response) =>
-            inReadLock(stateLock) {
-              val toRemove = 
transactionalIdByPartition(topicPartition.partition)
-              transactionMetadataCache.get(topicPartition.partition).foreach { 
txnMetadataCacheEntry =>
-                toRemove.foreach { idCoordinatorEpochAndMetadata =>
-                  val transactionalId = 
idCoordinatorEpochAndMetadata.transactionalId
-                  val txnMetadata = 
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
-                  txnMetadata.inLock {
-                    if (txnMetadataCacheEntry.coordinatorEpoch == 
idCoordinatorEpochAndMetadata.coordinatorEpoch
-                      && txnMetadata.pendingState.contains(Dead)
-                      && txnMetadata.producerEpoch == 
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
-                      && response.error == Errors.NONE) {
-                      
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
-                    } else {
-                      warn(s"Failed to remove expired transactionalId: 
$transactionalId" +
-                        s" from cache. Tombstone append error code: 
${response.error}," +
-                        s" pendingState: ${txnMetadata.pendingState}, 
producerEpoch: ${txnMetadata.producerEpoch}," +
-                        s" expected producerEpoch: 
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
-                        s" coordinatorEpoch: 
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
-                        s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
-                      txnMetadata.pendingState = None
-                    }
-                  }
+        case None =>
+          warn(s"Transaction expiration for partition $transactionPartition 
failed because the log " +
+            "config was not available, which likely means the partition is not 
online or is no longer local.")
+      }
+    }
+  }
+
+  private def shouldExpire(
+    txnMetadata: TransactionMetadata,
+    currentTimeMs: Long
+  ): Boolean = {
+    txnMetadata.state.isExpirationAllowed &&
+      txnMetadata.txnLastUpdateTimestamp <= currentTimeMs - 
config.transactionalIdExpirationMs
+  }
+
+  private def maybeAppendExpiration(
+    txnMetadata: TransactionMetadata,
+    recordsBuilder: MemoryRecordsBuilder,
+    currentTimeMs: Long,
+  ): Boolean = {
+    val keyBytes = TransactionLog.keyToBytes(txnMetadata.transactionalId)
+    if (recordsBuilder.hasRoomFor(currentTimeMs, keyBytes, null, 
Record.EMPTY_HEADERS)) {
+      recordsBuilder.append(currentTimeMs, keyBytes, null, 
Record.EMPTY_HEADERS)
+      true
+    } else {
+      false
+    }
+  }
+
+  private[transaction] def removeExpiredTransactionalIds(): Unit = {
+    inReadLock(stateLock) {
+      transactionMetadataCache.forKeyValue { (partitionId, 
partitionCacheEntry) =>
+        val transactionPartition = new 
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId)
+        removeExpiredTransactionalIds(transactionPartition, 
partitionCacheEntry)
+      }
+    }
+  }
+
+  private def writeTombstonesForExpiredTransactionalIds(
+    transactionPartition: TopicPartition,
+    expiredForPartition: Iterable[TransactionalIdCoordinatorEpochAndMetadata],
+    tombstoneRecords: MemoryRecords
+  ): Unit = {
+    def removeFromCacheCallback(responses: collection.Map[TopicPartition, 
PartitionResponse]): Unit = {
+      responses.forKeyValue { (topicPartition, response) =>
+        inReadLock(stateLock) {
+          transactionMetadataCache.get(topicPartition.partition).foreach { 
txnMetadataCacheEntry =>
+            expiredForPartition.foreach { idCoordinatorEpochAndMetadata =>
+              val transactionalId = 
idCoordinatorEpochAndMetadata.transactionalId
+              val txnMetadata = 
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
+              txnMetadata.inLock {
+                if (txnMetadataCacheEntry.coordinatorEpoch == 
idCoordinatorEpochAndMetadata.coordinatorEpoch
+                  && txnMetadata.pendingState.contains(Dead)
+                  && txnMetadata.producerEpoch == 
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
+                  && response.error == Errors.NONE) {
+                  
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
+                } else {
+                  warn(s"Failed to remove expired transactionalId: 
$transactionalId" +
+                    s" from cache. Tombstone append error code: 
${response.error}," +
+                    s" pendingState: ${txnMetadata.pendingState}, 
producerEpoch: ${txnMetadata.producerEpoch}," +
+                    s" expected producerEpoch: 
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
+                    s" coordinatorEpoch: 
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
+                    s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
+                  txnMetadata.pendingState = None
                 }
               }
             }
           }
         }
-
-        replicaManager.appendRecords(
-          config.requestTimeoutMs,
-          TransactionLog.EnforcedRequiredAcks,
-          internalTopicsAllowed = true,
-          origin = AppendOrigin.Coordinator,
-          recordsPerPartition,
-          removeFromCacheCallback,
-          requestLocal = RequestLocal.NoCaching)
       }
+    }
 
-    }, delay = config.removeExpiredTransactionalIdsIntervalMs, period = 
config.removeExpiredTransactionalIdsIntervalMs)
+    inReadLock(stateLock) {
+      replicaManager.appendRecords(
+        config.requestTimeoutMs,
+        TransactionLog.EnforcedRequiredAcks,
+        internalTopicsAllowed = true,
+        origin = AppendOrigin.Coordinator,
+        entriesPerPartition = Map(transactionPartition -> tombstoneRecords),
+        removeFromCacheCallback,
+        requestLocal = RequestLocal.NoCaching)
+    }
+  }
+
+  def enableTransactionalIdExpiration(): Unit = {
+    scheduler.schedule(
+      name = "transactionalId-expiration",
+      fun = removeExpiredTransactionalIds,
+      delay = config.removeExpiredTransactionalIdsIntervalMs,
+      period = config.removeExpiredTransactionalIdsIntervalMs
+    )
   }
 
   def getTransactionState(transactionalId: String): Either[Errors, 
Option[CoordinatorEpochAndTxnMetadata]] = {
@@ -689,7 +773,7 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def startup(retrieveTransactionTopicPartitionCount: () => Int, 
enableTransactionalIdExpiration: Boolean = true): Unit = {
+  def startup(retrieveTransactionTopicPartitionCount: () => Int, 
enableTransactionalIdExpiration: Boolean): Unit = {
     this.retrieveTransactionTopicPartitionCount = 
retrieveTransactionTopicPartitionCount
     transactionTopicPartitionCount = retrieveTransactionTopicPartitionCount()
     if (enableTransactionalIdExpiration)
diff --git a/core/src/main/scala/kafka/utils/Pool.scala 
b/core/src/main/scala/kafka/utils/Pool.scala
index d64ff5d..84bedc1 100644
--- a/core/src/main/scala/kafka/utils/Pool.scala
+++ b/core/src/main/scala/kafka/utils/Pool.scala
@@ -80,7 +80,16 @@ class Pool[K,V](valueFactory: Option[K => V] = None) extends 
Iterable[(K, V)] {
   def foreachEntry(f: (K, V) => Unit): Unit = {
     pool.forEach((k, v) => f(k, v))
   }
-  
+
+  def foreachWhile(f: (K, V) => Boolean): Unit = {
+    val iter = pool.entrySet().iterator()
+    var finished = false
+    while (!finished && iter.hasNext) {
+      val entry = iter.next
+      finished = !f(entry.getKey, entry.getValue)
+    }
+  }
+
   override def size: Int = pool.size
   
   override def iterator: Iterator[(K, V)] = new Iterator[(K,V)]() {
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
index 3fee14e..fac34a1 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -21,8 +21,9 @@ import java.util.concurrent.{ConcurrentHashMap, Executors}
 import java.util.{Collections, Random}
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.concurrent.locks.Lock
+
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Log, LogConfig}
 import kafka.server._
 import kafka.utils._
 import kafka.utils.timer.MockTimer
@@ -221,6 +222,10 @@ object AbstractCoordinatorConcurrencyTest {
       getOrCreateLogs().put(topicPartition, (log, endOffset))
     }
 
+    override def getLogConfig(topicPartition: TopicPartition): 
Option[LogConfig] = {
+      getOrCreateLogs().get(topicPartition).map(_._1.config)
+    }
+
     override def getLog(topicPartition: TopicPartition): Option[Log] =
       getOrCreateLogs().get(topicPartition).map(l => l._1)
 
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index e02c2fe3..3727a2a 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -17,11 +17,13 @@
 package kafka.coordinator.transaction
 
 import java.nio.ByteBuffer
+import java.util.Collections
 import java.util.concurrent.atomic.AtomicBoolean
+
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
 import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
-import kafka.log.Log
+import kafka.log.{Log, LogConfig}
 import kafka.server.{FetchDataInfo, FetchLogEnd, KafkaConfig, 
LogOffsetMetadata, MetadataCache, RequestLocal}
 import kafka.utils.{Pool, TestUtils}
 import org.apache.kafka.clients.{ClientResponse, NetworkClient}
@@ -72,7 +74,8 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
 
     txnStateManager = new TransactionStateManager(0, scheduler, 
replicaManager, txnConfig, time,
       new Metrics())
-    txnStateManager.startup(() => 
zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get)
+    txnStateManager.startup(() => 
zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get,
+      enableTransactionalIdExpiration = true)
     for (i <- 0 until numPartitions)
       txnStateManager.addLoadedTransactionsToCache(i, coordinatorEpoch, new 
Pool[String, TransactionMetadata]())
 
@@ -455,8 +458,9 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
   }
 
   private def prepareTxnLog(partitionId: Int): Unit = {
-
     val logMock: Log =  EasyMock.mock(classOf[Log])
+    EasyMock.expect(logMock.config).andStubReturn(new 
LogConfig(Collections.emptyMap()))
+
     val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords])
 
     val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
partitionId)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 410d6e2..21629bd 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -20,8 +20,9 @@ import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent.CountDownLatch
 import java.util.concurrent.locks.ReentrantLock
+
 import javax.management.ObjectName
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Defaults, Log, LogConfig}
 import kafka.server.{FetchDataInfo, FetchLogEnd, LogOffsetMetadata, 
ReplicaManager, RequestLocal}
 import kafka.utils.{MockScheduler, Pool, TestUtils}
 import kafka.zk.KafkaZkClient
@@ -37,8 +38,8 @@ import org.easymock.{Capture, EasyMock, IAnswer}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
-import scala.jdk.CollectionConverters._
 import scala.collection.{Map, mutable}
+import scala.jdk.CollectionConverters._
 
 class TransactionStateManagerTest {
 
@@ -78,7 +79,7 @@ class TransactionStateManagerTest {
 
   @BeforeEach
   def setUp(): Unit = {
-    transactionManager.startup(() => numPartitions, false)
+    transactionManager.startup(() => numPartitions, 
enableTransactionalIdExpiration = false)
     // make sure the transactional id hashes to the assigning partition id
     assertEquals(partitionId, 
transactionManager.partitionFor(transactionalId1))
     assertEquals(partitionId, 
transactionManager.partitionFor(transactionalId2))
@@ -581,7 +582,7 @@ class TransactionStateManagerTest {
   }
 
   @Test
-  def shouldRemoveCompleteCommmitExpiredTransactionalIds(): Unit = {
+  def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = {
     setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit)
     verifyMetadataDoesntExist(transactionalId1)
     verifyMetadataDoesExistAndIsUsable(transactionalId2)
@@ -630,6 +631,157 @@ class TransactionStateManagerTest {
   }
 
   @Test
+  def testTransactionalExpirationWithTooSmallBatchSize(): Unit = {
+    // The batch size is too small, but we nevertheless expect the
+    // coordinator to attempt the append. This test mainly ensures
+    // that the expiration task does not get stuck.
+
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 16
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 20)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val attemptedAppends = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.MESSAGE_TOO_LARGE, attemptedAppends)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    for (batches <- attemptedAppends.values; batch <- batches) {
+      assertTrue(batch.sizeInBytes() > maxBatchSize)
+    }
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionalExpirationWithOfflineLogDir(): Unit = {
+    val onlinePartitionId = 0
+    val offlinePartitionId = 1
+
+    val partitionIds = Seq(onlinePartitionId, offlinePartitionId)
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 20)
+
+    EasyMock.reset(replicaManager)
+
+    // Partition 0 returns log config as normal
+    expectLogConfig(Seq(onlinePartitionId), maxBatchSize)
+    // No log config returned for partition 1 since it is offline
+    EasyMock.expect(replicaManager.getLogConfig(new 
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, offlinePartitionId)))
+      .andStubReturn(None)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set(onlinePartitionId), 
appendedRecords.keySet.map(_.partition))
+
+    val (transactionalIdsForOnlinePartition, 
transactionalIdsForOfflinePartition) =
+      allTransactionalIds.partition { transactionalId =>
+        transactionManager.partitionFor(transactionalId) == onlinePartitionId
+      }
+
+    val expiredTransactionalIds = 
collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(transactionalIdsForOnlinePartition, expiredTransactionalIds)
+    assertEquals(transactionalIdsForOfflinePartition, 
listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionExpirationShouldRespectBatchSize(): Unit = {
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 1000)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set.empty, listExpirableTransactionalIds())
+    assertEquals(partitionIds.toSet, appendedRecords.keys.map(_.partition))
+
+    appendedRecords.values.foreach { batches =>
+      assertTrue(batches.size > 1) // Ensure a non-trivial test case
+      assertTrue(batches.forall(_.sizeInBytes() < maxBatchSize))
+    }
+
+    val expiredTransactionalIds = 
collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(allTransactionalIds, expiredTransactionalIds)
+  }
+
+  private def collectTransactionalIdsFromTombstones(
+    appendedRecords: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+  ): Set[String] = {
+    val expiredTransactionalIds = mutable.Set.empty[String]
+    appendedRecords.values.foreach { batches =>
+      batches.foreach { records =>
+        records.records.forEach { record =>
+          val transactionalId = 
TransactionLog.readTxnRecordKey(record.key).transactionalId
+          assertNull(record.value)
+          expiredTransactionalIds += transactionalId
+          assertEquals(Right(None), 
transactionManager.getTransactionState(transactionalId))
+        }
+      }
+    }
+    expiredTransactionalIds.toSet
+  }
+
+  private def loadExpiredTransactionalIds(
+    numTransactionalIds: Int
+  ): Set[String] = {
+    val allTransactionalIds = mutable.Set.empty[String]
+    for (i <- 0 to numTransactionalIds) {
+      val txnlId = s"id_$i"
+      val producerId = i
+      val txnMetadata = transactionMetadata(txnlId, producerId)
+      txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - 
txnConfig.transactionalIdExpirationMs
+      transactionManager.putTransactionStateIfNotExists(txnMetadata)
+      allTransactionalIds += txnlId
+    }
+    allTransactionalIds.toSet
+  }
+
+  private def listExpirableTransactionalIds(): Set[String] = {
+    val activeTransactionalIds = 
transactionManager.listTransactionStates(Set.empty, Set.empty)
+      .transactionStates
+      .asScala
+      .map(_.transactionalId)
+
+    activeTransactionalIds.filter { transactionalId =>
+      transactionManager.getTransactionState(transactionalId) match {
+        case Right(Some(epochAndMetadata)) =>
+          val txnMetadata = epochAndMetadata.transactionMetadata
+          val timeSinceLastUpdate = time.milliseconds() - 
txnMetadata.txnLastUpdateTimestamp
+          timeSinceLastUpdate >= txnConfig.transactionalIdExpirationMs &&
+            txnMetadata.state.isExpirationAllowed &&
+            txnMetadata.pendingState.isEmpty
+        case _ => false
+      }
+    }.toSet
+  }
+
+  @Test
   def testSuccessfulReimmigration(): Unit = {
     txnMetadata1.state = PrepareCommit
     txnMetadata1.addPartitions(Set[TopicPartition](new 
TopicPartition("topic1", 0),
@@ -701,36 +853,66 @@ class TransactionStateManagerTest {
     }
   }
 
-  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: 
TransactionState): Unit = {
-    for (partitionId <- 0 until numPartitions) {
+  private def expectTransactionalIdExpiration(
+    appendError: Errors,
+    capturedAppends: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+  ): Unit = {
+    val recordsCapture: Capture[Map[TopicPartition, MemoryRecords]] = 
EasyMock.newCapture()
+    val callbackCapture: Capture[Map[TopicPartition, PartitionResponse] => 
Unit] = EasyMock.newCapture()
+
+    EasyMock.expect(replicaManager.appendRecords(
+      EasyMock.anyLong(),
+      EasyMock.eq((-1).toShort),
+      EasyMock.eq(true),
+      EasyMock.eq(AppendOrigin.Coordinator),
+      EasyMock.capture(recordsCapture),
+      EasyMock.capture(callbackCapture),
+      EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
+      EasyMock.anyObject(),
+      EasyMock.anyObject()
+    )).andAnswer(() => callbackCapture.getValue.apply(
+      recordsCapture.getValue.map { case (topicPartition, records) =>
+        val batches = capturedAppends.getOrElse(topicPartition, {
+          val batches = mutable.Buffer.empty[MemoryRecords]
+          capturedAppends += topicPartition -> batches
+          batches
+        })
+
+        batches += records
+
+        topicPartition -> new PartitionResponse(appendError, 0L, 
RecordBatch.NO_TIMESTAMP, 0L)
+      }.toMap
+    )).anyTimes()
+  }
+
+  private def loadTransactionsForPartitions(
+    partitionIds: Seq[Int],
+  ): Unit = {
+    for (partitionId <- partitionIds) {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new 
Pool[String, TransactionMetadata]())
     }
+  }
 
-    val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] => 
Unit] = EasyMock.newCapture()
+  private def expectLogConfig(
+    partitionIds: Seq[Int],
+    maxBatchSize: Int
+  ): Unit = {
+    val logConfig: LogConfig = EasyMock.mock(classOf[LogConfig])
+    EasyMock.expect(logConfig.maxMessageSize).andStubReturn(maxBatchSize)
 
-    val partition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
transactionManager.partitionFor(transactionalId1))
-    val recordsByPartition = Map(partition -> 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType,
-      new SimpleRecord(time.milliseconds() + 
txnConfig.removeExpiredTransactionalIdsIntervalMs, 
TransactionLog.keyToBytes(transactionalId1), null)))
-
-    txnState match {
-      case Empty | CompleteCommit | CompleteAbort =>
-
-        EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(),
-          EasyMock.eq((-1).toShort),
-          EasyMock.eq(true),
-          EasyMock.eq(AppendOrigin.Coordinator),
-          EasyMock.eq(recordsByPartition),
-          EasyMock.capture(capturedArgument),
-          EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
-          EasyMock.anyObject(),
-          EasyMock.anyObject()
-        )).andAnswer(() => capturedArgument.getValue.apply(
-          Map(partition -> new PartitionResponse(error, 0L, 
RecordBatch.NO_TIMESTAMP, 0L)))
-        )
-      case _ => // shouldn't append
+    for (partitionId <- partitionIds) {
+      EasyMock.expect(replicaManager.getLogConfig(new 
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId)))
+        .andStubReturn(Some(logConfig))
     }
 
-    EasyMock.replay(replicaManager)
+    EasyMock.replay(logConfig)
+  }
+
+  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: 
TransactionState): Unit = {
+    val partitionIds = 0 until numPartitions
+
+    loadTransactionsForPartitions(partitionIds)
+    expectLogConfig(partitionIds, Defaults.MaxMessageSize)
 
     txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - 
txnConfig.transactionalIdExpirationMs
     txnMetadata1.state = txnState
@@ -739,12 +921,28 @@ class TransactionStateManagerTest {
     txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
     transactionManager.putTransactionStateIfNotExists(txnMetadata2)
 
-    transactionManager.enableTransactionalIdExpiration()
-    time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs)
-
-    scheduler.tick()
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(error, appendedRecords)
 
+    EasyMock.replay(replicaManager)
+    transactionManager.removeExpiredTransactionalIds()
     EasyMock.verify(replicaManager)
+
+    val stateAllowsExpiration = txnState match {
+      case Empty | CompleteCommit | CompleteAbort => true
+      case _ => false
+    }
+
+    if (stateAllowsExpiration) {
+      val partitionId = transactionManager.partitionFor(transactionalId1)
+      val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
partitionId)
+      val expectedTombstone = new SimpleRecord(time.milliseconds(), 
TransactionLog.keyToBytes(transactionalId1), null)
+      val expectedRecords = 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, 
expectedTombstone)
+      assertEquals(Set(topicPartition), appendedRecords.keySet)
+      assertEquals(Seq(expectedRecords), appendedRecords(topicPartition).toSeq)
+    } else {
+      assertEquals(Map.empty, appendedRecords)
+    }
   }
 
   private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): 
Unit = {

Reply via email to