This is an automated email from the ASF dual-hosted git repository.
jgus pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/3.0 by this push:
new cd233e9 KAFKA-13099; Transactional expiration should account for max
batch size (#11098)
cd233e9 is described below
commit cd233e92b0a17797f3d46b119cb8cf18d049be62
Author: Jason Gustafson <[email protected]>
AuthorDate: Tue Jul 27 18:23:00 2021 -0700
KAFKA-13099; Transactional expiration should account for max batch size
(#11098)
When expiring transactionalIds, we group the tombstones together into
batches. Currently there is no limit on the size of these batches, which can
lead to `MESSAGE_TOO_LARGE` errors when a bunch of transactionalIds need to be
expired at the same time. This patch fixes the problem by ensuring that the
batch size respects the configured limit. Any transactionalIds which are
eligible for expiration and cannot be fit into the batch are postponed until
the next periodic check.
Reviewers: David Jacot <[email protected]>, Guozhang Wang
<[email protected]>
---
.../apache/kafka/common/record/MemoryRecords.java | 14 ++
.../transaction/TransactionMetadata.scala | 5 +
.../transaction/TransactionStateManager.scala | 208 +++++++++++-----
core/src/main/scala/kafka/utils/Pool.scala | 11 +-
.../AbstractCoordinatorConcurrencyTest.scala | 7 +-
.../TransactionCoordinatorConcurrencyTest.scala | 10 +-
.../transaction/TransactionStateManagerTest.scala | 262 ++++++++++++++++++---
7 files changed, 418 insertions(+), 99 deletions(-)
diff --git
a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index 1991759..b631171 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -408,6 +408,20 @@ public class MemoryRecords extends AbstractRecords {
return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE,
compressionType, timestampType, baseOffset);
}
+ public static MemoryRecordsBuilder builder(ByteBuffer buffer,
+ CompressionType compressionType,
+ TimestampType timestampType,
+ long baseOffset,
+ int maxSize) {
+ long logAppendTime = RecordBatch.NO_TIMESTAMP;
+ if (timestampType == TimestampType.LOG_APPEND_TIME)
+ logAppendTime = System.currentTimeMillis();
+
+ return new MemoryRecordsBuilder(buffer,
RecordBatch.CURRENT_MAGIC_VALUE, compressionType, timestampType, baseOffset,
+ logAppendTime, RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE,
+ false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, maxSize);
+ }
+
public static MemoryRecordsBuilder idempotentBuilder(ByteBuffer buffer,
CompressionType
compressionType,
long baseOffset,
diff --git
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index 1027468..0f6d4b7 100644
---
a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++
b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -66,6 +66,8 @@ private[transaction] sealed trait TransactionState {
def name: String
def validPreviousStates: Set[TransactionState]
+
+ def isExpirationAllowed: Boolean = false
}
/**
@@ -78,6 +80,7 @@ private[transaction] case object Empty extends
TransactionState {
val id: Byte = 0
val name: String = "Empty"
val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteCommit,
CompleteAbort)
+ override def isExpirationAllowed: Boolean = true
}
/**
@@ -125,6 +128,7 @@ private[transaction] case object CompleteCommit extends
TransactionState {
val id: Byte = 4
val name: String = "CompleteCommit"
val validPreviousStates: Set[TransactionState] = Set(PrepareCommit)
+ override def isExpirationAllowed: Boolean = true
}
/**
@@ -136,6 +140,7 @@ private[transaction] case object CompleteAbort extends
TransactionState {
val id: Byte = 5
val name: String = "CompleteAbort"
val validPreviousStates: Set[TransactionState] = Set(PrepareAbort)
+ override def isExpirationAllowed: Boolean = true
}
/**
diff --git
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 25580f2..217b383 100644
---
a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++
b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -21,6 +21,7 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantReadWriteLock
+
import kafka.log.{AppendOrigin, LogConfig}
import kafka.message.UncompressedCodec
import kafka.server.{Defaults, FetchLogEnd, ReplicaManager, RequestLocal}
@@ -32,7 +33,7 @@ import
org.apache.kafka.common.message.ListTransactionsResponseData
import org.apache.kafka.common.metrics.Metrics
import org.apache.kafka.common.metrics.stats.{Avg, Max}
import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.record.{FileRecords, MemoryRecords,
SimpleRecord}
+import org.apache.kafka.common.record.{FileRecords, MemoryRecords,
MemoryRecordsBuilder, Record, SimpleRecord, TimestampType}
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.{Time, Utils}
@@ -140,79 +141,162 @@ class TransactionStateManager(brokerId: Int,
}
}
- def enableTransactionalIdExpiration(): Unit = {
- scheduler.schedule("transactionalId-expiration", () => {
- val now = time.milliseconds()
- inReadLock(stateLock) {
- val transactionalIdByPartition: Map[Int,
mutable.Iterable[TransactionalIdCoordinatorEpochAndMetadata]] =
- transactionMetadataCache.flatMap { case (_, entry) =>
- entry.metadataPerTransactionalId.filter { case (_, txnMetadata) =>
txnMetadata.state match {
- case Empty | CompleteCommit | CompleteAbort => true
- case _ => false
- }
- }.filter { case (_, txnMetadata) =>
- txnMetadata.txnLastUpdateTimestamp <= now -
config.transactionalIdExpirationMs
- }.map { case (transactionalId, txnMetadata) =>
- val txnMetadataTransition = txnMetadata.inLock {
- txnMetadata.prepareDead()
+ private def removeExpiredTransactionalIds(
+ transactionPartition: TopicPartition,
+ txnMetadataCacheEntry: TxnMetadataCacheEntry,
+ ): Unit = {
+ inReadLock(stateLock) {
+ replicaManager.getLogConfig(transactionPartition) match {
+ case Some(logConfig) =>
+ val currentTimeMs = time.milliseconds()
+ val maxBatchSize = logConfig.maxMessageSize
+ val expired =
mutable.ListBuffer.empty[TransactionalIdCoordinatorEpochAndMetadata]
+ var recordsBuilder: MemoryRecordsBuilder = null
+ val stateEntries =
txnMetadataCacheEntry.metadataPerTransactionalId.values.iterator.buffered
+
+ def flushRecordsBuilder(): Unit = {
+ writeTombstonesForExpiredTransactionalIds(
+ transactionPartition,
+ expired.toSeq,
+ recordsBuilder.build()
+ )
+ expired.clear()
+ recordsBuilder = null
+ }
+
+ while (stateEntries.hasNext) {
+ val txnMetadata = stateEntries.head
+ val transactionalId = txnMetadata.transactionalId
+ var fullBatch = false
+
+ txnMetadata.inLock {
+ if (txnMetadata.pendingState.isEmpty &&
shouldExpire(txnMetadata, currentTimeMs)) {
+ if (recordsBuilder == null) {
+ recordsBuilder = MemoryRecords.builder(
+ ByteBuffer.allocate(math.min(16384, maxBatchSize)),
+ TransactionLog.EnforcedCompressionType,
+ TimestampType.CREATE_TIME,
+ 0L,
+ maxBatchSize
+ )
+ }
+
+ if (maybeAppendExpiration(txnMetadata, recordsBuilder,
currentTimeMs)) {
+ val transitMetadata = txnMetadata.prepareDead()
+ expired += TransactionalIdCoordinatorEpochAndMetadata(
+ transactionalId,
+ txnMetadataCacheEntry.coordinatorEpoch,
+ transitMetadata
+ )
+ } else {
+ fullBatch = true
+ }
}
- TransactionalIdCoordinatorEpochAndMetadata(transactionalId,
entry.coordinatorEpoch, txnMetadataTransition)
}
- }.groupBy { transactionalIdCoordinatorEpochAndMetadata =>
-
partitionFor(transactionalIdCoordinatorEpochAndMetadata.transactionalId)
+
+ if (fullBatch) {
+ flushRecordsBuilder()
+ } else {
+ // Advance the iterator if we do not need to retry the append
+ stateEntries.next()
+ }
}
- val recordsPerPartition = transactionalIdByPartition
- .map { case (partition, transactionalIdCoordinatorEpochAndMetadatas)
=>
- val deletes: Array[SimpleRecord] =
transactionalIdCoordinatorEpochAndMetadatas.map { entry =>
- new SimpleRecord(now,
TransactionLog.keyToBytes(entry.transactionalId), null)
- }.toArray
- val records =
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, deletes: _*)
- val topicPartition = new
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partition)
- (topicPartition, records)
+ if (expired.nonEmpty) {
+ flushRecordsBuilder()
}
- def removeFromCacheCallback(responses: collection.Map[TopicPartition,
PartitionResponse]): Unit = {
- responses.forKeyValue { (topicPartition, response) =>
- inReadLock(stateLock) {
- val toRemove =
transactionalIdByPartition(topicPartition.partition)
- transactionMetadataCache.get(topicPartition.partition).foreach {
txnMetadataCacheEntry =>
- toRemove.foreach { idCoordinatorEpochAndMetadata =>
- val transactionalId =
idCoordinatorEpochAndMetadata.transactionalId
- val txnMetadata =
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
- txnMetadata.inLock {
- if (txnMetadataCacheEntry.coordinatorEpoch ==
idCoordinatorEpochAndMetadata.coordinatorEpoch
- && txnMetadata.pendingState.contains(Dead)
- && txnMetadata.producerEpoch ==
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
- && response.error == Errors.NONE) {
-
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
- } else {
- warn(s"Failed to remove expired transactionalId:
$transactionalId" +
- s" from cache. Tombstone append error code:
${response.error}," +
- s" pendingState: ${txnMetadata.pendingState},
producerEpoch: ${txnMetadata.producerEpoch}," +
- s" expected producerEpoch:
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
- s" coordinatorEpoch:
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
- s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
- txnMetadata.pendingState = None
- }
- }
+ case None =>
+ warn(s"Transaction expiration for partition $transactionPartition
failed because the log " +
+ "config was not available, which likely means the partition is not
online or is no longer local.")
+ }
+ }
+ }
+
+ private def shouldExpire(
+ txnMetadata: TransactionMetadata,
+ currentTimeMs: Long
+ ): Boolean = {
+ txnMetadata.state.isExpirationAllowed &&
+ txnMetadata.txnLastUpdateTimestamp <= currentTimeMs -
config.transactionalIdExpirationMs
+ }
+
+ private def maybeAppendExpiration(
+ txnMetadata: TransactionMetadata,
+ recordsBuilder: MemoryRecordsBuilder,
+ currentTimeMs: Long,
+ ): Boolean = {
+ val keyBytes = TransactionLog.keyToBytes(txnMetadata.transactionalId)
+ if (recordsBuilder.hasRoomFor(currentTimeMs, keyBytes, null,
Record.EMPTY_HEADERS)) {
+ recordsBuilder.append(currentTimeMs, keyBytes, null,
Record.EMPTY_HEADERS)
+ true
+ } else {
+ false
+ }
+ }
+
+ private[transaction] def removeExpiredTransactionalIds(): Unit = {
+ inReadLock(stateLock) {
+ transactionMetadataCache.forKeyValue { (partitionId,
partitionCacheEntry) =>
+ val transactionPartition = new
TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId)
+ removeExpiredTransactionalIds(transactionPartition,
partitionCacheEntry)
+ }
+ }
+ }
+
+ private def writeTombstonesForExpiredTransactionalIds(
+ transactionPartition: TopicPartition,
+ expiredForPartition: Iterable[TransactionalIdCoordinatorEpochAndMetadata],
+ tombstoneRecords: MemoryRecords
+ ): Unit = {
+ def removeFromCacheCallback(responses: collection.Map[TopicPartition,
PartitionResponse]): Unit = {
+ responses.forKeyValue { (topicPartition, response) =>
+ inReadLock(stateLock) {
+ transactionMetadataCache.get(topicPartition.partition).foreach {
txnMetadataCacheEntry =>
+ expiredForPartition.foreach { idCoordinatorEpochAndMetadata =>
+ val transactionalId =
idCoordinatorEpochAndMetadata.transactionalId
+ val txnMetadata =
txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
+ txnMetadata.inLock {
+ if (txnMetadataCacheEntry.coordinatorEpoch ==
idCoordinatorEpochAndMetadata.coordinatorEpoch
+ && txnMetadata.pendingState.contains(Dead)
+ && txnMetadata.producerEpoch ==
idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
+ && response.error == Errors.NONE) {
+
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
+ } else {
+ warn(s"Failed to remove expired transactionalId:
$transactionalId" +
+ s" from cache. Tombstone append error code:
${response.error}," +
+ s" pendingState: ${txnMetadata.pendingState},
producerEpoch: ${txnMetadata.producerEpoch}," +
+ s" expected producerEpoch:
${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
+ s" coordinatorEpoch:
${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
+ s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
+ txnMetadata.pendingState = None
}
}
}
}
}
-
- replicaManager.appendRecords(
- config.requestTimeoutMs,
- TransactionLog.EnforcedRequiredAcks,
- internalTopicsAllowed = true,
- origin = AppendOrigin.Coordinator,
- recordsPerPartition,
- removeFromCacheCallback,
- requestLocal = RequestLocal.NoCaching)
}
+ }
- }, delay = config.removeExpiredTransactionalIdsIntervalMs, period =
config.removeExpiredTransactionalIdsIntervalMs)
+ inReadLock(stateLock) {
+ replicaManager.appendRecords(
+ config.requestTimeoutMs,
+ TransactionLog.EnforcedRequiredAcks,
+ internalTopicsAllowed = true,
+ origin = AppendOrigin.Coordinator,
+ entriesPerPartition = Map(transactionPartition -> tombstoneRecords),
+ removeFromCacheCallback,
+ requestLocal = RequestLocal.NoCaching)
+ }
+ }
+
+ def enableTransactionalIdExpiration(): Unit = {
+ scheduler.schedule(
+ name = "transactionalId-expiration",
+ fun = removeExpiredTransactionalIds,
+ delay = config.removeExpiredTransactionalIdsIntervalMs,
+ period = config.removeExpiredTransactionalIdsIntervalMs
+ )
}
def getTransactionState(transactionalId: String): Either[Errors,
Option[CoordinatorEpochAndTxnMetadata]] = {
@@ -689,7 +773,7 @@ class TransactionStateManager(brokerId: Int,
}
}
- def startup(retrieveTransactionTopicPartitionCount: () => Int,
enableTransactionalIdExpiration: Boolean = true): Unit = {
+ def startup(retrieveTransactionTopicPartitionCount: () => Int,
enableTransactionalIdExpiration: Boolean): Unit = {
this.retrieveTransactionTopicPartitionCount =
retrieveTransactionTopicPartitionCount
transactionTopicPartitionCount = retrieveTransactionTopicPartitionCount()
if (enableTransactionalIdExpiration)
diff --git a/core/src/main/scala/kafka/utils/Pool.scala
b/core/src/main/scala/kafka/utils/Pool.scala
index d64ff5d..84bedc1 100644
--- a/core/src/main/scala/kafka/utils/Pool.scala
+++ b/core/src/main/scala/kafka/utils/Pool.scala
@@ -80,7 +80,16 @@ class Pool[K,V](valueFactory: Option[K => V] = None) extends
Iterable[(K, V)] {
def foreachEntry(f: (K, V) => Unit): Unit = {
pool.forEach((k, v) => f(k, v))
}
-
+
+ def foreachWhile(f: (K, V) => Boolean): Unit = {
+ val iter = pool.entrySet().iterator()
+ var finished = false
+ while (!finished && iter.hasNext) {
+ val entry = iter.next
+ finished = !f(entry.getKey, entry.getValue)
+ }
+ }
+
override def size: Int = pool.size
override def iterator: Iterator[(K, V)] = new Iterator[(K,V)]() {
diff --git
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
index 3fee14e..fac34a1 100644
---
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
+++
b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -21,8 +21,9 @@ import java.util.concurrent.{ConcurrentHashMap, Executors}
import java.util.{Collections, Random}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.Lock
+
import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Log, LogConfig}
import kafka.server._
import kafka.utils._
import kafka.utils.timer.MockTimer
@@ -221,6 +222,10 @@ object AbstractCoordinatorConcurrencyTest {
getOrCreateLogs().put(topicPartition, (log, endOffset))
}
+ override def getLogConfig(topicPartition: TopicPartition):
Option[LogConfig] = {
+ getOrCreateLogs().get(topicPartition).map(_._1.config)
+ }
+
override def getLog(topicPartition: TopicPartition): Option[Log] =
getOrCreateLogs().get(topicPartition).map(l => l._1)
diff --git
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index e02c2fe3..3727a2a 100644
---
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -17,11 +17,13 @@
package kafka.coordinator.transaction
import java.nio.ByteBuffer
+import java.util.Collections
import java.util.concurrent.atomic.AtomicBoolean
+
import kafka.coordinator.AbstractCoordinatorConcurrencyTest
import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
-import kafka.log.Log
+import kafka.log.{Log, LogConfig}
import kafka.server.{FetchDataInfo, FetchLogEnd, KafkaConfig,
LogOffsetMetadata, MetadataCache, RequestLocal}
import kafka.utils.{Pool, TestUtils}
import org.apache.kafka.clients.{ClientResponse, NetworkClient}
@@ -72,7 +74,8 @@ class TransactionCoordinatorConcurrencyTest extends
AbstractCoordinatorConcurren
txnStateManager = new TransactionStateManager(0, scheduler,
replicaManager, txnConfig, time,
new Metrics())
- txnStateManager.startup(() =>
zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get)
+ txnStateManager.startup(() =>
zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get,
+ enableTransactionalIdExpiration = true)
for (i <- 0 until numPartitions)
txnStateManager.addLoadedTransactionsToCache(i, coordinatorEpoch, new
Pool[String, TransactionMetadata]())
@@ -455,8 +458,9 @@ class TransactionCoordinatorConcurrencyTest extends
AbstractCoordinatorConcurren
}
private def prepareTxnLog(partitionId: Int): Unit = {
-
val logMock: Log = EasyMock.mock(classOf[Log])
+ EasyMock.expect(logMock.config).andStubReturn(new
LogConfig(Collections.emptyMap()))
+
val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords])
val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME,
partitionId)
diff --git
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 410d6e2..21629bd 100644
---
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -20,8 +20,9 @@ import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.concurrent.CountDownLatch
import java.util.concurrent.locks.ReentrantLock
+
import javax.management.ObjectName
-import kafka.log.{AppendOrigin, Log}
+import kafka.log.{AppendOrigin, Defaults, Log, LogConfig}
import kafka.server.{FetchDataInfo, FetchLogEnd, LogOffsetMetadata,
ReplicaManager, RequestLocal}
import kafka.utils.{MockScheduler, Pool, TestUtils}
import kafka.zk.KafkaZkClient
@@ -37,8 +38,8 @@ import org.easymock.{Capture, EasyMock, IAnswer}
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
-import scala.jdk.CollectionConverters._
import scala.collection.{Map, mutable}
+import scala.jdk.CollectionConverters._
class TransactionStateManagerTest {
@@ -78,7 +79,7 @@ class TransactionStateManagerTest {
@BeforeEach
def setUp(): Unit = {
- transactionManager.startup(() => numPartitions, false)
+ transactionManager.startup(() => numPartitions,
enableTransactionalIdExpiration = false)
// make sure the transactional id hashes to the assigning partition id
assertEquals(partitionId,
transactionManager.partitionFor(transactionalId1))
assertEquals(partitionId,
transactionManager.partitionFor(transactionalId2))
@@ -581,7 +582,7 @@ class TransactionStateManagerTest {
}
@Test
- def shouldRemoveCompleteCommmitExpiredTransactionalIds(): Unit = {
+ def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit)
verifyMetadataDoesntExist(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
@@ -630,6 +631,157 @@ class TransactionStateManagerTest {
}
@Test
+ def testTransactionalExpirationWithTooSmallBatchSize(): Unit = {
+ // The batch size is too small, but we nevertheless expect the
+ // coordinator to attempt the append. This test mainly ensures
+ // that the expiration task does not get stuck.
+
+ val partitionIds = 0 until numPartitions
+ val maxBatchSize = 16
+
+ loadTransactionsForPartitions(partitionIds)
+ val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds
= 20)
+
+ EasyMock.reset(replicaManager)
+ expectLogConfig(partitionIds, maxBatchSize)
+
+ val attemptedAppends = mutable.Map.empty[TopicPartition,
mutable.Buffer[MemoryRecords]]
+ expectTransactionalIdExpiration(Errors.MESSAGE_TOO_LARGE, attemptedAppends)
+ EasyMock.replay(replicaManager)
+
+ assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+ transactionManager.removeExpiredTransactionalIds()
+ EasyMock.verify(replicaManager)
+
+ for (batches <- attemptedAppends.values; batch <- batches) {
+ assertTrue(batch.sizeInBytes() > maxBatchSize)
+ }
+
+ assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+ }
+
+ @Test
+ def testTransactionalExpirationWithOfflineLogDir(): Unit = {
+ val onlinePartitionId = 0
+ val offlinePartitionId = 1
+
+ val partitionIds = Seq(onlinePartitionId, offlinePartitionId)
+ val maxBatchSize = 512
+
+ loadTransactionsForPartitions(partitionIds)
+ val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds
= 20)
+
+ EasyMock.reset(replicaManager)
+
+ // Partition 0 returns log config as normal
+ expectLogConfig(Seq(onlinePartitionId), maxBatchSize)
+ // No log config returned for partition 1 since it is offline
+ EasyMock.expect(replicaManager.getLogConfig(new
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, offlinePartitionId)))
+ .andStubReturn(None)
+
+ val appendedRecords = mutable.Map.empty[TopicPartition,
mutable.Buffer[MemoryRecords]]
+ expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+ EasyMock.replay(replicaManager)
+
+ assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+ transactionManager.removeExpiredTransactionalIds()
+ EasyMock.verify(replicaManager)
+
+ assertEquals(Set(onlinePartitionId),
appendedRecords.keySet.map(_.partition))
+
+ val (transactionalIdsForOnlinePartition,
transactionalIdsForOfflinePartition) =
+ allTransactionalIds.partition { transactionalId =>
+ transactionManager.partitionFor(transactionalId) == onlinePartitionId
+ }
+
+ val expiredTransactionalIds =
collectTransactionalIdsFromTombstones(appendedRecords)
+ assertEquals(transactionalIdsForOnlinePartition, expiredTransactionalIds)
+ assertEquals(transactionalIdsForOfflinePartition,
listExpirableTransactionalIds())
+ }
+
+ @Test
+ def testTransactionExpirationShouldRespectBatchSize(): Unit = {
+ val partitionIds = 0 until numPartitions
+ val maxBatchSize = 512
+
+ loadTransactionsForPartitions(partitionIds)
+ val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds
= 1000)
+
+ EasyMock.reset(replicaManager)
+ expectLogConfig(partitionIds, maxBatchSize)
+
+ val appendedRecords = mutable.Map.empty[TopicPartition,
mutable.Buffer[MemoryRecords]]
+ expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+ EasyMock.replay(replicaManager)
+
+ assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+ transactionManager.removeExpiredTransactionalIds()
+ EasyMock.verify(replicaManager)
+
+ assertEquals(Set.empty, listExpirableTransactionalIds())
+ assertEquals(partitionIds.toSet, appendedRecords.keys.map(_.partition))
+
+ appendedRecords.values.foreach { batches =>
+ assertTrue(batches.size > 1) // Ensure a non-trivial test case
+ assertTrue(batches.forall(_.sizeInBytes() < maxBatchSize))
+ }
+
+ val expiredTransactionalIds =
collectTransactionalIdsFromTombstones(appendedRecords)
+ assertEquals(allTransactionalIds, expiredTransactionalIds)
+ }
+
+ private def collectTransactionalIdsFromTombstones(
+ appendedRecords: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+ ): Set[String] = {
+ val expiredTransactionalIds = mutable.Set.empty[String]
+ appendedRecords.values.foreach { batches =>
+ batches.foreach { records =>
+ records.records.forEach { record =>
+ val transactionalId =
TransactionLog.readTxnRecordKey(record.key).transactionalId
+ assertNull(record.value)
+ expiredTransactionalIds += transactionalId
+ assertEquals(Right(None),
transactionManager.getTransactionState(transactionalId))
+ }
+ }
+ }
+ expiredTransactionalIds.toSet
+ }
+
+ private def loadExpiredTransactionalIds(
+ numTransactionalIds: Int
+ ): Set[String] = {
+ val allTransactionalIds = mutable.Set.empty[String]
+ for (i <- 0 to numTransactionalIds) {
+ val txnlId = s"id_$i"
+ val producerId = i
+ val txnMetadata = transactionMetadata(txnlId, producerId)
+ txnMetadata.txnLastUpdateTimestamp = time.milliseconds() -
txnConfig.transactionalIdExpirationMs
+ transactionManager.putTransactionStateIfNotExists(txnMetadata)
+ allTransactionalIds += txnlId
+ }
+ allTransactionalIds.toSet
+ }
+
+ private def listExpirableTransactionalIds(): Set[String] = {
+ val activeTransactionalIds =
transactionManager.listTransactionStates(Set.empty, Set.empty)
+ .transactionStates
+ .asScala
+ .map(_.transactionalId)
+
+ activeTransactionalIds.filter { transactionalId =>
+ transactionManager.getTransactionState(transactionalId) match {
+ case Right(Some(epochAndMetadata)) =>
+ val txnMetadata = epochAndMetadata.transactionMetadata
+ val timeSinceLastUpdate = time.milliseconds() -
txnMetadata.txnLastUpdateTimestamp
+ timeSinceLastUpdate >= txnConfig.transactionalIdExpirationMs &&
+ txnMetadata.state.isExpirationAllowed &&
+ txnMetadata.pendingState.isEmpty
+ case _ => false
+ }
+ }.toSet
+ }
+
+ @Test
def testSuccessfulReimmigration(): Unit = {
txnMetadata1.state = PrepareCommit
txnMetadata1.addPartitions(Set[TopicPartition](new
TopicPartition("topic1", 0),
@@ -701,36 +853,66 @@ class TransactionStateManagerTest {
}
}
- private def setupAndRunTransactionalIdExpiration(error: Errors, txnState:
TransactionState): Unit = {
- for (partitionId <- 0 until numPartitions) {
+ private def expectTransactionalIdExpiration(
+ appendError: Errors,
+ capturedAppends: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]]
+ ): Unit = {
+ val recordsCapture: Capture[Map[TopicPartition, MemoryRecords]] =
EasyMock.newCapture()
+ val callbackCapture: Capture[Map[TopicPartition, PartitionResponse] =>
Unit] = EasyMock.newCapture()
+
+ EasyMock.expect(replicaManager.appendRecords(
+ EasyMock.anyLong(),
+ EasyMock.eq((-1).toShort),
+ EasyMock.eq(true),
+ EasyMock.eq(AppendOrigin.Coordinator),
+ EasyMock.capture(recordsCapture),
+ EasyMock.capture(callbackCapture),
+ EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
+ EasyMock.anyObject(),
+ EasyMock.anyObject()
+ )).andAnswer(() => callbackCapture.getValue.apply(
+ recordsCapture.getValue.map { case (topicPartition, records) =>
+ val batches = capturedAppends.getOrElse(topicPartition, {
+ val batches = mutable.Buffer.empty[MemoryRecords]
+ capturedAppends += topicPartition -> batches
+ batches
+ })
+
+ batches += records
+
+ topicPartition -> new PartitionResponse(appendError, 0L,
RecordBatch.NO_TIMESTAMP, 0L)
+ }.toMap
+ )).anyTimes()
+ }
+
+ private def loadTransactionsForPartitions(
+ partitionIds: Seq[Int],
+ ): Unit = {
+ for (partitionId <- partitionIds) {
transactionManager.addLoadedTransactionsToCache(partitionId, 0, new
Pool[String, TransactionMetadata]())
}
+ }
- val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] =>
Unit] = EasyMock.newCapture()
+ private def expectLogConfig(
+ partitionIds: Seq[Int],
+ maxBatchSize: Int
+ ): Unit = {
+ val logConfig: LogConfig = EasyMock.mock(classOf[LogConfig])
+ EasyMock.expect(logConfig.maxMessageSize).andStubReturn(maxBatchSize)
- val partition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME,
transactionManager.partitionFor(transactionalId1))
- val recordsByPartition = Map(partition ->
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType,
- new SimpleRecord(time.milliseconds() +
txnConfig.removeExpiredTransactionalIdsIntervalMs,
TransactionLog.keyToBytes(transactionalId1), null)))
-
- txnState match {
- case Empty | CompleteCommit | CompleteAbort =>
-
- EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(),
- EasyMock.eq((-1).toShort),
- EasyMock.eq(true),
- EasyMock.eq(AppendOrigin.Coordinator),
- EasyMock.eq(recordsByPartition),
- EasyMock.capture(capturedArgument),
- EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]],
- EasyMock.anyObject(),
- EasyMock.anyObject()
- )).andAnswer(() => capturedArgument.getValue.apply(
- Map(partition -> new PartitionResponse(error, 0L,
RecordBatch.NO_TIMESTAMP, 0L)))
- )
- case _ => // shouldn't append
+ for (partitionId <- partitionIds) {
+ EasyMock.expect(replicaManager.getLogConfig(new
TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId)))
+ .andStubReturn(Some(logConfig))
}
- EasyMock.replay(replicaManager)
+ EasyMock.replay(logConfig)
+ }
+
+ private def setupAndRunTransactionalIdExpiration(error: Errors, txnState:
TransactionState): Unit = {
+ val partitionIds = 0 until numPartitions
+
+ loadTransactionsForPartitions(partitionIds)
+ expectLogConfig(partitionIds, Defaults.MaxMessageSize)
txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() -
txnConfig.transactionalIdExpirationMs
txnMetadata1.state = txnState
@@ -739,12 +921,28 @@ class TransactionStateManagerTest {
txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
transactionManager.putTransactionStateIfNotExists(txnMetadata2)
- transactionManager.enableTransactionalIdExpiration()
- time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs)
-
- scheduler.tick()
+ val appendedRecords = mutable.Map.empty[TopicPartition,
mutable.Buffer[MemoryRecords]]
+ expectTransactionalIdExpiration(error, appendedRecords)
+ EasyMock.replay(replicaManager)
+ transactionManager.removeExpiredTransactionalIds()
EasyMock.verify(replicaManager)
+
+ val stateAllowsExpiration = txnState match {
+ case Empty | CompleteCommit | CompleteAbort => true
+ case _ => false
+ }
+
+ if (stateAllowsExpiration) {
+ val partitionId = transactionManager.partitionFor(transactionalId1)
+ val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME,
partitionId)
+ val expectedTombstone = new SimpleRecord(time.milliseconds(),
TransactionLog.keyToBytes(transactionalId1), null)
+ val expectedRecords =
MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType,
expectedTombstone)
+ assertEquals(Set(topicPartition), appendedRecords.keySet)
+ assertEquals(Seq(expectedRecords), appendedRecords(topicPartition).toSeq)
+ } else {
+ assertEquals(Map.empty, appendedRecords)
+ }
}
private def verifyWritesTxnMarkersInPrepareState(state: TransactionState):
Unit = {