This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.8
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.8 by this push:
     new facf5e9  KAFKA-13099; Transactional expiration should account for max 
batch size (#11098)
facf5e9 is described below

commit facf5e9be9e39fc950d33641c9817fa7c54b0258
Author: Jason Gustafson <[email protected]>
AuthorDate: Tue Jul 27 18:23:00 2021 -0700

    KAFKA-13099; Transactional expiration should account for max batch size 
(#11098)
    
    When expiring transactionalIds, we group the tombstones together into 
batches. Currently there is no limit on the size of these batches, which can 
lead to `MESSAGE_TOO_LARGE` errors when a bunch of transactionalIds need to be 
expired at the same time. This patch fixes the problem by ensuring that the 
batch size respects the configured limit. Any transactionalIds which are 
eligible for expiration and cannot be fit into the current batch are written 
in follow-up batches during the same expiration pass.
    
    Reviewers: David Jacot <[email protected]>, Guozhang Wang 
<[email protected]>
---
 .../apache/kafka/common/record/MemoryRecords.java  |  14 ++
 .../transaction/TransactionMetadata.scala          |  20 +-
 .../transaction/TransactionStateManager.scala      | 205 +++++++++++-----
 core/src/main/scala/kafka/utils/Pool.scala         |  11 +-
 .../AbstractCoordinatorConcurrencyTest.scala       |   7 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |   9 +-
 .../transaction/TransactionStateManagerTest.scala  | 261 ++++++++++++++++++---
 7 files changed, 425 insertions(+), 102 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java 
b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index ae844bf..d28049c 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -406,6 +406,20 @@ public class MemoryRecords extends AbstractRecords {
         return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
compressionType, timestampType, baseOffset);
     }
 
+    public static MemoryRecordsBuilder builder(ByteBuffer buffer,
+                                               CompressionType compressionType,
+                                               TimestampType timestampType,
+                                               long baseOffset,
+                                               int maxSize) {
+        long logAppendTime = RecordBatch.NO_TIMESTAMP;
+        if (timestampType == TimestampType.LOG_APPEND_TIME)
+            logAppendTime = System.currentTimeMillis();
+
+        return new MemoryRecordsBuilder(buffer, 
RecordBatch.CURRENT_MAGIC_VALUE, compressionType, timestampType, baseOffset,
+            logAppendTime, RecordBatch.NO_PRODUCER_ID, 
RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE,
+            false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, maxSize);
+    }
+
     public static MemoryRecordsBuilder idempotentBuilder(ByteBuffer buffer,
                                                          CompressionType 
compressionType,
                                                          long baseOffset,
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index e059b04..aba6cba 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -25,7 +25,10 @@ import org.apache.kafka.common.record.RecordBatch
 
 import scala.collection.{immutable, mutable}
 
-private[transaction] sealed trait TransactionState { def byte: Byte }
+private[transaction] sealed trait TransactionState {
+  def byte: Byte
+  def isExpirationAllowed: Boolean = false
+}
 
 /**
  * Transaction has not existed yet
@@ -33,7 +36,10 @@ private[transaction] sealed trait TransactionState { def 
byte: Byte }
  * transition: received AddPartitionsToTxnRequest => Ongoing
  *             received AddOffsetsToTxnRequest => Ongoing
  */
-private[transaction] case object Empty extends TransactionState { val byte: 
Byte = 0 }
+private[transaction] case object Empty extends TransactionState {
+  val byte: Byte = 0
+  override def isExpirationAllowed: Boolean = true
+}
 
 /**
  * Transaction has started and ongoing
@@ -64,14 +70,20 @@ private[transaction] case object PrepareAbort extends 
TransactionState { val byt
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[transaction] case object CompleteCommit extends TransactionState { val 
byte: Byte = 4 }
+private[transaction] case object CompleteCommit extends TransactionState {
+  val byte: Byte = 4
+  override def isExpirationAllowed: Boolean = true
+}
 
 /**
  * Group has completed abort
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[transaction] case object CompleteAbort extends TransactionState { val 
byte: Byte = 5 }
+private[transaction] case object CompleteAbort extends TransactionState {
+  val byte: Byte = 5
+  override def isExpirationAllowed: Boolean = true
+}
 
 /**
   * TransactionalId has expired and is about to be removed from the 
transaction cache
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index b547896..3e5413b 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -32,7 +32,7 @@ import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.metrics.stats.{Avg, Max}
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.record.{FileRecords, MemoryRecords, 
SimpleRecord}
+import org.apache.kafka.common.record.{FileRecords, MemoryRecords, 
MemoryRecordsBuilder, Record, SimpleRecord, TimestampType}
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.{Time, Utils}
@@ -140,78 +140,161 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def enableTransactionalIdExpiration(): Unit = {
-    scheduler.schedule("transactionalId-expiration", () => {
-      val now = time.milliseconds()
-      inReadLock(stateLock) {
-        val transactionalIdByPartition: Map[Int, 
mutable.Iterable[TransactionalIdCoordinatorEpochAndMetadata]] =
-          transactionMetadataCache.flatMap { case (_, entry) =>
-            entry.metadataPerTransactionalId.filter { case (_, txnMetadata) => 
txnMetadata.state match {
-              case Empty | CompleteCommit | CompleteAbort => true
-              case _ => false
-            }
-            }.filter { case (_, txnMetadata) =>
-              txnMetadata.txnLastUpdateTimestamp <= now - 
config.transactionalIdExpirationMs
-            }.map { case (transactionalId, txnMetadata) =>
-              val txnMetadataTransition = txnMetadata.inLock {
-                txnMetadata.prepareDead()
+  private def removeExpiredTransactionalIds(
+    transactionPartition: TopicPartition,
+    txnMetadataCacheEntry: TxnMetadataCacheEntry,
+  ): Unit = {
+    inReadLock(stateLock) {
+      replicaManager.getLogConfig(transactionPartition) match {
+        case Some(logConfig) =>
+          val currentTimeMs = time.milliseconds()
+          val maxBatchSize = logConfig.maxMessageSize
+          val expired = 
mutable.ListBuffer.empty[TransactionalIdCoordinatorEpochAndMetadata]
+          var recordsBuilder: MemoryRecordsBuilder = null
+          val stateEntries = 
txnMetadataCacheEntry.metadataPerTransactionalId.values.iterator.buffered
+
+          def flushRecordsBuilder(): Unit = {
+            writeTombstonesForExpiredTransactionalIds(
+              transactionPartition,
+              expired.toSeq,
+              recordsBuilder.build()
+            )
+            expired.clear()
+            recordsBuilder = null
+          }
+
+          while (stateEntries.hasNext) {
+            val txnMetadata = stateEntries.head
+            val transactionalId = txnMetadata.transactionalId
+            var fullBatch = false
+
+            txnMetadata.inLock {
+              if (txnMetadata.pendingState.isEmpty && 
shouldExpire(txnMetadata, currentTimeMs)) {
+                if (recordsBuilder == null) {
+                  recordsBuilder = MemoryRecords.builder(
+                    ByteBuffer.allocate(math.min(16384, maxBatchSize)),
+                    TransactionLog.EnforcedCompressionType,
+                    TimestampType.CREATE_TIME,
+                    0L,
+                    maxBatchSize
+                  )
+                }
+
+                if (maybeAppendExpiration(txnMetadata, recordsBuilder, 
currentTimeMs)) {
+                  val transitMetadata = txnMetadata.prepareDead()
+                  expired += TransactionalIdCoordinatorEpochAndMetadata(
+                    transactionalId,
+                    txnMetadataCacheEntry.coordinatorEpoch,
+                    transitMetadata
+                  )
+                } else {
+                  fullBatch = true
+                }
               }
-              TransactionalIdCoordinatorEpochAndMetadata(transactionalId, 
entry.coordinatorEpoch, txnMetadataTransition)
             }
-          }.groupBy { transactionalIdCoordinatorEpochAndMetadata =>
-            
partitionFor(transactionalIdCoordinatorEpochAndMetadata.transactionalId)
+
+            if (fullBatch) {
+              flushRecordsBuilder()
+            } else {
+              // Advance the iterator if we do not need to retry the append
+              stateEntries.next()
+            }
           }
 
-        val recordsPerPartition = transactionalIdByPartition
-          .map { case (partition, transactionalIdCoordinatorEpochAndMetadatas) 
=>
-            val deletes: Array[SimpleRecord] = 
transactionalIdCoordinatorEpochAndMetadatas.map { entry =>
-              new SimpleRecord(now, 
TransactionLog.keyToBytes(entry.transactionalId), null)
-            }.toArray
-            val records = 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, deletes: _*)
-            val topicPartition = new 
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partition)
-            (topicPartition, records)
+          if (expired.nonEmpty) {
+            flushRecordsBuilder()
           }
 
-        def removeFromCacheCallback(responses: collection.Map[TopicPartition, 
PartitionResponse]): Unit = {
-          responses.forKeyValue { (topicPartition, response) =>
-            inReadLock(stateLock) {
-              val toRemove = 
transactionalIdByPartition(topicPartition.partition)
-              transactionMetadataCache.get(topicPartition.partition).foreach { 
txnMetadataCacheEntry =>
-                toRemove.foreach { idCoordinatorEpochAndMetadata =>
-                  val transactionalId = 
idCoordinatorEpochAndMetadata.transactionalId
-                  val txnMetadata = 
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
-                  txnMetadata.inLock {
-                    if (txnMetadataCacheEntry.coordinatorEpoch == 
idCoordinatorEpochAndMetadata.coordinatorEpoch
-                      && txnMetadata.pendingState.contains(Dead)
-                      && txnMetadata.producerEpoch == 
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
-                      && response.error == Errors.NONE) {
-                      
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
-                    } else {
-                      warn(s"Failed to remove expired transactionalId: 
$transactionalId" +
-                        s" from cache. Tombstone append error code: 
${response.error}," +
-                        s" pendingState: ${txnMetadata.pendingState}, 
producerEpoch: ${txnMetadata.producerEpoch}," +
-                        s" expected producerEpoch: 
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
-                        s" coordinatorEpoch: 
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
-                        s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
-                      txnMetadata.pendingState = None
-                    }
-                  }
+        case None =>
+          warn(s"Transaction expiration for partition $transactionPartition 
failed because the log " +
+            "config was not available, which likely means the partition is not 
online or is no longer local.")
+      }
+    }
+  }
+
+  private def shouldExpire(
+    txnMetadata: TransactionMetadata,
+    currentTimeMs: Long
+  ): Boolean = {
+    txnMetadata.state.isExpirationAllowed &&
+      txnMetadata.txnLastUpdateTimestamp <= currentTimeMs - 
config.transactionalIdExpirationMs
+  }
+
+  private def maybeAppendExpiration(
+    txnMetadata: TransactionMetadata,
+    recordsBuilder: MemoryRecordsBuilder,
+    currentTimeMs: Long,
+  ): Boolean = {
+    val keyBytes = TransactionLog.keyToBytes(txnMetadata.transactionalId)
+    if (recordsBuilder.hasRoomFor(currentTimeMs, keyBytes, null, 
Record.EMPTY_HEADERS)) {
+      recordsBuilder.append(currentTimeMs, keyBytes, null, 
Record.EMPTY_HEADERS)
+      true
+    } else {
+      false
+    }
+  }
+
+  private[transaction] def removeExpiredTransactionalIds(): Unit = {
+    inReadLock(stateLock) {
+      transactionMetadataCache.forKeyValue { (partitionId, 
partitionCacheEntry) =>
+        val transactionPartition = new 
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId)
+        removeExpiredTransactionalIds(transactionPartition, 
partitionCacheEntry)
+      }
+    }
+  }
+
+  private def writeTombstonesForExpiredTransactionalIds(
+    transactionPartition: TopicPartition,
+    expiredForPartition: Iterable[TransactionalIdCoordinatorEpochAndMetadata],
+    tombstoneRecords: MemoryRecords
+  ): Unit = {
+    def removeFromCacheCallback(responses: collection.Map[TopicPartition, 
PartitionResponse]): Unit = {
+      responses.forKeyValue { (topicPartition, response) =>
+        inReadLock(stateLock) {
+          transactionMetadataCache.get(topicPartition.partition).foreach { 
txnMetadataCacheEntry =>
+            expiredForPartition.foreach { idCoordinatorEpochAndMetadata =>
+              val transactionalId = 
idCoordinatorEpochAndMetadata.transactionalId
+              val txnMetadata = 
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
+              txnMetadata.inLock {
+                if (txnMetadataCacheEntry.coordinatorEpoch == 
idCoordinatorEpochAndMetadata.coordinatorEpoch
+                  && txnMetadata.pendingState.contains(Dead)
+                  && txnMetadata.producerEpoch == 
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
+                  && response.error == Errors.NONE) {
+                  
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
+                } else {
+                  warn(s"Failed to remove expired transactionalId: 
$transactionalId" +
+                    s" from cache. Tombstone append error code: 
${response.error}," +
+                    s" pendingState: ${txnMetadata.pendingState}, 
producerEpoch: ${txnMetadata.producerEpoch}," +
+                    s" expected producerEpoch: 
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
+                    s" coordinatorEpoch: 
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
+                    s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
+                  txnMetadata.pendingState = None
                 }
               }
             }
           }
         }
-
-        replicaManager.appendRecords(
-          config.requestTimeoutMs,
-          TransactionLog.EnforcedRequiredAcks,
-          internalTopicsAllowed = true,
-          origin = AppendOrigin.Coordinator,
-          recordsPerPartition,
-          removeFromCacheCallback)
       }
+    }
 
-    }, delay = config.removeExpiredTransactionalIdsIntervalMs, period = 
config.removeExpiredTransactionalIdsIntervalMs)
+    inReadLock(stateLock) {
+      replicaManager.appendRecords(
+        config.requestTimeoutMs,
+        TransactionLog.EnforcedRequiredAcks,
+        internalTopicsAllowed = true,
+        origin = AppendOrigin.Coordinator,
+        entriesPerPartition = Map(transactionPartition -> tombstoneRecords),
+        removeFromCacheCallback)
+    }
+  }
+
+  def enableTransactionalIdExpiration(): Unit = {
+    scheduler.schedule(
+      name = "transactionalId-expiration",
+      fun = removeExpiredTransactionalIds,
+      delay = config.removeExpiredTransactionalIdsIntervalMs,
+      period = config.removeExpiredTransactionalIdsIntervalMs
+    )
   }
 
   def getTransactionState(transactionalId: String): Either[Errors, 
Option[CoordinatorEpochAndTxnMetadata]] = {
@@ -634,7 +717,7 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def startup(retrieveTransactionTopicPartitionCount: () => Int, 
enableTransactionalIdExpiration: Boolean = true): Unit = {
+  def startup(retrieveTransactionTopicPartitionCount: () => Int, 
enableTransactionalIdExpiration: Boolean): Unit = {
     this.retrieveTransactionTopicPartitionCount = 
retrieveTransactionTopicPartitionCount
     transactionTopicPartitionCount = retrieveTransactionTopicPartitionCount()
     if (enableTransactionalIdExpiration)
diff --git a/core/src/main/scala/kafka/utils/Pool.scala 
b/core/src/main/scala/kafka/utils/Pool.scala
index d64ff5d..84bedc1 100644
--- a/core/src/main/scala/kafka/utils/Pool.scala
+++ b/core/src/main/scala/kafka/utils/Pool.scala
@@ -80,7 +80,16 @@ class Pool[K,V](valueFactory: Option[K => V] = None) extends 
Iterable[(K, V)] {
   def foreachEntry(f: (K, V) => Unit): Unit = {
     pool.forEach((k, v) => f(k, v))
   }
-  
+
+  def foreachWhile(f: (K, V) => Boolean): Unit = {
+    val iter = pool.entrySet().iterator()
+    var finished = false
+    while (!finished && iter.hasNext) {
+      val entry = iter.next
+      finished = !f(entry.getKey, entry.getValue)
+    }
+  }
+
   override def size: Int = pool.size
   
   override def iterator: Iterator[(K, V)] = new Iterator[(K,V)]() {
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
index 14e0f3f..d8d4c8d 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger
 import java.util.concurrent.locks.Lock
 
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Log, LogConfig}
 import kafka.server._
 import kafka.utils._
 import kafka.utils.timer.MockTimer
@@ -216,6 +216,11 @@ object AbstractCoordinatorConcurrencyTest {
     def updateLog(topicPartition: TopicPartition, log: Log, endOffset: Long): 
Unit = {
       getOrCreateLogs().put(topicPartition, (log, endOffset))
     }
+
+    override def getLogConfig(topicPartition: TopicPartition): 
Option[LogConfig] = {
+      getOrCreateLogs().get(topicPartition).map(_._1.config)
+    }
+
     override def getLog(topicPartition: TopicPartition): Option[Log] =
       getOrCreateLogs().get(topicPartition).map(l => l._1)
     override def getLogEndOffset(topicPartition: TopicPartition): Option[Long] 
=
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index ad33703..05f894f 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -17,12 +17,13 @@
 package kafka.coordinator.transaction
 
 import java.nio.ByteBuffer
+import java.util.Collections
 import java.util.concurrent.atomic.AtomicBoolean
 
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
 import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
-import kafka.log.Log
+import kafka.log.{Log, LogConfig}
 import kafka.server.{FetchDataInfo, FetchLogEnd, KafkaConfig, 
LogOffsetMetadata, MetadataCache}
 import kafka.utils.{Pool, TestUtils}
 import org.apache.kafka.clients.{ClientResponse, NetworkClient}
@@ -73,7 +74,8 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
 
     txnStateManager = new TransactionStateManager(0, scheduler, 
replicaManager, txnConfig, time,
       new Metrics())
-    txnStateManager.startup(() => 
zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get)
+    txnStateManager.startup(() => 
zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get,
+      enableTransactionalIdExpiration = true)
     for (i <- 0 until numPartitions)
       txnStateManager.addLoadedTransactionsToCache(i, coordinatorEpoch, new 
Pool[String, TransactionMetadata]())
 
@@ -456,8 +458,9 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
   }
 
   private def prepareTxnLog(partitionId: Int): Unit = {
-
     val logMock: Log =  EasyMock.mock(classOf[Log])
+    EasyMock.expect(logMock.config).andStubReturn(new 
LogConfig(Collections.emptyMap()))
+
     val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords])
 
     val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
partitionId)
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 20ca0c4..f285abb 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -20,8 +20,9 @@ import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent.CountDownLatch
 import java.util.concurrent.locks.ReentrantLock
+
 import javax.management.ObjectName
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Defaults, Log, LogConfig}
 import kafka.server.{FetchDataInfo, FetchLogEnd, LogOffsetMetadata, 
ReplicaManager}
 import kafka.utils.{MockScheduler, Pool, TestUtils}
 import kafka.zk.KafkaZkClient
@@ -34,11 +35,11 @@ import 
org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.MockTime
 import org.easymock.{Capture, EasyMock, IAnswer}
-import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, 
assertThrows, assertTrue, fail}
+import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, 
assertNull, assertThrows, assertTrue, fail}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
-import scala.jdk.CollectionConverters._
 import scala.collection.{Map, mutable}
+import scala.jdk.CollectionConverters._
 
 class TransactionStateManagerTest {
 
@@ -78,7 +79,7 @@ class TransactionStateManagerTest {
 
   @BeforeEach
   def setUp(): Unit = {
-    transactionManager.startup(() => numPartitions, false)
+    transactionManager.startup(() => numPartitions, 
enableTransactionalIdExpiration = false)
     // make sure the transactional id hashes to the assigning partition id
     assertEquals(partitionId, 
transactionManager.partitionFor(transactionalId1))
     assertEquals(partitionId, 
transactionManager.partitionFor(transactionalId2))
@@ -512,7 +513,7 @@ class TransactionStateManagerTest {
   }
 
   @Test
-  def shouldRemoveCompleteCommmitExpiredTransactionalIds(): Unit = {
+  def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = {
     setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit)
     verifyMetadataDoesntExist(transactionalId1)
     verifyMetadataDoesExistAndIsUsable(transactionalId2)
@@ -561,6 +562,156 @@ class TransactionStateManagerTest {
   }
 
   @Test
+  def testTransactionalExpirationWithTooSmallBatchSize(): Unit = {
+    // The batch size is too small, but we nevertheless expect the
+    // coordinator to attempt the append. This test mainly ensures
+    // that the expiration task does not get stuck.
+
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 16
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 20)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val attemptedAppends = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.MESSAGE_TOO_LARGE, attemptedAppends)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    for (batches <- attemptedAppends.values; batch <- batches) {
+      assertTrue(batch.sizeInBytes() > maxBatchSize)
+    }
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionalExpirationWithOfflineLogDir(): Unit = {
+    val onlinePartitionId = 0
+    val offlinePartitionId = 1
+
+    val partitionIds = Seq(onlinePartitionId, offlinePartitionId)
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 20)
+
+    EasyMock.reset(replicaManager)
+
+    // Partition 0 returns log config as normal
+    expectLogConfig(Seq(onlinePartitionId), maxBatchSize)
+    // No log config returned for partition 1 since it is offline
+    EasyMock.expect(replicaManager.getLogConfig(new 
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, offlinePartitionId)))
+      .andStubReturn(None)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set(onlinePartitionId), 
appendedRecords.keySet.map(_.partition))
+
+    val (transactionalIdsForOnlinePartition, 
transactionalIdsForOfflinePartition) =
+      allTransactionalIds.partition { transactionalId =>
+        transactionManager.partitionFor(transactionalId) == onlinePartitionId
+      }
+
+    val expiredTransactionalIds = 
collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(transactionalIdsForOnlinePartition, expiredTransactionalIds)
+    assertEquals(transactionalIdsForOfflinePartition, 
listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionExpirationShouldRespectBatchSize(): Unit = {
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds 
= 1000)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set.empty, listExpirableTransactionalIds())
+    assertEquals(partitionIds.toSet, appendedRecords.keys.map(_.partition))
+
+    appendedRecords.values.foreach { batches =>
+      assertTrue(batches.size > 1) // Ensure a non-trivial test case
+      assertTrue(batches.forall(_.sizeInBytes() < maxBatchSize))
+    }
+
+    val expiredTransactionalIds = 
collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(allTransactionalIds, expiredTransactionalIds)
+  }
+
+  private def collectTransactionalIdsFromTombstones(
+    appendedRecords: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+  ): Set[String] = {
+    val expiredTransactionalIds = mutable.Set.empty[String]
+    appendedRecords.values.foreach { batches =>
+      batches.foreach { records =>
+        records.records.forEach { record =>
+          val transactionalId = 
TransactionLog.readTxnRecordKey(record.key).transactionalId
+          assertNull(record.value)
+          expiredTransactionalIds += transactionalId
+          assertEquals(Right(None), 
transactionManager.getTransactionState(transactionalId))
+        }
+      }
+    }
+    expiredTransactionalIds.toSet
+  }
+
+  private def loadExpiredTransactionalIds(
+    numTransactionalIds: Int
+  ): Set[String] = {
+    val allTransactionalIds = mutable.Set.empty[String]
+    for (i <- 0 to numTransactionalIds) {
+      val txnlId = s"id_$i"
+      val producerId = i
+      val txnMetadata = transactionMetadata(txnlId, producerId)
+      txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - 
txnConfig.transactionalIdExpirationMs
+      transactionManager.putTransactionStateIfNotExists(txnMetadata)
+      allTransactionalIds += txnlId
+    }
+    allTransactionalIds.toSet
+  }
+
+  private def listExpirableTransactionalIds(): Set[String] = {
+    val activeTransactionalIds = transactionManager.transactionMetadataCache
+        .values
+        .flatMap(_.metadataPerTransactionalId.values.map(_.transactionalId))
+
+    activeTransactionalIds.filter { transactionalId =>
+      transactionManager.getTransactionState(transactionalId) match {
+        case Right(Some(epochAndMetadata)) =>
+          val txnMetadata = epochAndMetadata.transactionMetadata
+          val timeSinceLastUpdate = time.milliseconds() - 
txnMetadata.txnLastUpdateTimestamp
+          timeSinceLastUpdate >= txnConfig.transactionalIdExpirationMs &&
+            txnMetadata.state.isExpirationAllowed &&
+            txnMetadata.pendingState.isEmpty
+        case _ => false
+      }
+    }.toSet
+  }
+
+  @Test
   def testSuccessfulReimmigration(): Unit = {
     txnMetadata1.state = PrepareCommit
     txnMetadata1.addPartitions(Set[TopicPartition](new 
TopicPartition("topic1", 0),
@@ -632,35 +783,65 @@ class TransactionStateManagerTest {
     }
   }
 
-  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: 
TransactionState): Unit = {
-    for (partitionId <- 0 until numPartitions) {
+  /**
+   * Stubs `replicaManager.appendRecords` (any number of calls) for appends made
+   * with `AppendOrigin.Coordinator`. On each call the appended records are
+   * added to `capturedAppends` under their partition, and the captured callback
+   * is invoked with a `PartitionResponse` carrying `appendError` for every
+   * partition in the request.
+   *
+   * @param appendError     error placed in every partition response
+   * @param capturedAppends accumulator, per partition, of the record sets
+   *                        appended in call order; mutated by this stub
+   */
+  private def expectTransactionalIdExpiration(
+    appendError: Errors,
+    capturedAppends: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+  ): Unit = {
+    val recordsCapture: Capture[Map[TopicPartition, MemoryRecords]] =
+      EasyMock.newCapture()
+    val callbackCapture: Capture[Map[TopicPartition, PartitionResponse] => Unit] =
+      EasyMock.newCapture()
+
+    EasyMock.expect(replicaManager.appendRecords(
+      EasyMock.anyLong(),
+      EasyMock.eq((-1).toShort),
+      EasyMock.eq(true),
+      EasyMock.eq(AppendOrigin.Coordinator),
+      EasyMock.capture(recordsCapture),
+      EasyMock.capture(callbackCapture),
+      EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
+      EasyMock.anyObject()
+    )).andAnswer(() => callbackCapture.getValue.apply(
+      recordsCapture.getValue.map { case (topicPartition, records) =>
+        // Create the per-partition buffer on first use, then record the append.
+        val batches = capturedAppends.getOrElseUpdate(
+          topicPartition, mutable.Buffer.empty[MemoryRecords])
+        batches += records
+
+        topicPartition ->
+          new PartitionResponse(appendError, 0L, RecordBatch.NO_TIMESTAMP, 0L)
+      }.toMap
+    )).anyTimes()
+  }
+
+  /**
+   * Calls `transactionManager.addLoadedTransactionsToCache` once per given
+   * partition id, passing a freshly created, empty `Pool` each time.
+   * NOTE(review): the literal `0` is presumably the coordinator epoch; confirm
+   * against the signature of `addLoadedTransactionsToCache`.
+   *
+   * @param partitionIds partition ids to register
+   */
+  private def loadTransactionsForPartitions(
+    partitionIds: Seq[Int],
+  ): Unit = {
+    for (partitionId <- partitionIds) {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new 
Pool[String, TransactionMetadata]())
     }
+  }
 
-    val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] => 
Unit] = EasyMock.newCapture()
+  /**
+   * Creates a mocked `LogConfig` whose `maxMessageSize` is stubbed to return
+   * `maxBatchSize`, and stubs `replicaManager.getLogConfig` to return it for
+   * each given partition of `TRANSACTION_STATE_TOPIC_NAME`.
+   *
+   * @param partitionIds partition ids for which the log config is stubbed
+   * @param maxBatchSize value returned by the mocked `maxMessageSize`
+   */
+  private def expectLogConfig(
+    partitionIds: Seq[Int],
+    maxBatchSize: Int
+  ): Unit = {
+    val logConfig: LogConfig = EasyMock.mock(classOf[LogConfig])
+    EasyMock.expect(logConfig.maxMessageSize).andStubReturn(maxBatchSize)
 
-    val partition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
transactionManager.partitionFor(transactionalId1))
-    val recordsByPartition = Map(partition -> 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType,
-      new SimpleRecord(time.milliseconds() + 
txnConfig.removeExpiredTransactionalIdsIntervalMs, 
TransactionLog.keyToBytes(transactionalId1), null)))
-
-    txnState match {
-      case Empty | CompleteCommit | CompleteAbort =>
-
-        EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(),
-          EasyMock.eq((-1).toShort),
-          EasyMock.eq(true),
-          EasyMock.eq(AppendOrigin.Coordinator),
-          EasyMock.eq(recordsByPartition),
-          EasyMock.capture(capturedArgument),
-          EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
-          EasyMock.anyObject()
-        )).andAnswer(() => capturedArgument.getValue.apply(
-          Map(partition -> new PartitionResponse(error, 0L, 
RecordBatch.NO_TIMESTAMP, 0L)))
-        )
-      case _ => // shouldn't append
+    // All listed partitions share the single mocked log config.
+    for (partitionId <- partitionIds) {
+      EasyMock.expect(replicaManager.getLogConfig(new 
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId)))
+        .andStubReturn(Some(logConfig))
     }
 
-    EasyMock.replay(replicaManager)
+    // Only the log config mock is replayed here; `replicaManager` is not.
+    EasyMock.replay(logConfig)
+  }
+
+  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: 
TransactionState): Unit = {
+    val partitionIds = 0 until numPartitions
+
+    loadTransactionsForPartitions(partitionIds)
+    expectLogConfig(partitionIds, Defaults.MaxMessageSize)
 
     txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - 
txnConfig.transactionalIdExpirationMs
     txnMetadata1.state = txnState
@@ -669,12 +850,28 @@ class TransactionStateManagerTest {
     txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
     transactionManager.putTransactionStateIfNotExists(txnMetadata2)
 
-    transactionManager.enableTransactionalIdExpiration()
-    time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs)
-
-    scheduler.tick()
+    val appendedRecords = mutable.Map.empty[TopicPartition, 
mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(error, appendedRecords)
 
+    EasyMock.replay(replicaManager)
+    transactionManager.removeExpiredTransactionalIds()
     EasyMock.verify(replicaManager)
+
+    val stateAllowsExpiration = txnState match {
+      case Empty | CompleteCommit | CompleteAbort => true
+      case _ => false
+    }
+
+    if (stateAllowsExpiration) {
+      val partitionId = transactionManager.partitionFor(transactionalId1)
+      val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, 
partitionId)
+      val expectedTombstone = new SimpleRecord(time.milliseconds(), 
TransactionLog.keyToBytes(transactionalId1), null)
+      val expectedRecords = 
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, 
expectedTombstone)
+      assertEquals(Set(topicPartition), appendedRecords.keySet)
+      assertEquals(Seq(expectedRecords), appendedRecords(topicPartition).toSeq)
+    } else {
+      assertEquals(Map.empty, appendedRecords)
+    }
   }
 
   private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): 
Unit = {

Reply via email to