This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 1dbcb7da9e3 KAFKA-14694: RPCProducerIdManager should not wait on new block (#13267)
1dbcb7da9e3 is described below

commit 1dbcb7da9e3625ec2078a82f84542a3127730fef
Author: Jeff Kim <kimkb2...@gmail.com>
AuthorDate: Thu Jun 22 13:19:39 2023 -0400

    KAFKA-14694: RPCProducerIdManager should not wait on new block (#13267)
    
    RPCProducerIdManager initiates an async request to the controller to
    grab a block of producer IDs and then blocks waiting for a response
    from the controller.
    
    This is done in the request handler threads while holding a global
    lock. This means that if many producers are requesting producer IDs
    and the controller is slow to respond, many threads can get stuck
    waiting for the lock.
    
    This patch aims to:
    * resolve the deadlock scenario mentioned above by not waiting for a
      new block and returning an error immediately
    * remove synchronization usages in RPCProducerIdManager.generateProducerId()
    * handle errors returned from generateProducerId() so that KafkaApis
      does not log unexpected errors
    * confirm producers back off before retrying
    * introduce backoff if the manager fails to process AllocateProducerIdsResponse
    
    Reviewers: Artem Livshits <alivsh...@confluent.io>, Jason Gustafson <ja...@confluent.io>
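
    For context, a minimal standalone sketch of the lock-free claim idea
    this patch adopts (hypothetical class name IdBlock; simplified from
    the ProducerIdsBlock.claimNextId() added below):

        import java.util.concurrent.atomic.AtomicLong

        // Ids are handed out with an atomic counter; exhaustion is signaled
        // with None instead of blocking the calling request handler thread.
        final class IdBlock(first: Long, size: Int) {
          private val counter = new AtomicLong(first)
          def claimNextId(): Option[Long] = {
            val id = counter.getAndIncrement()
            if (id < first + size) Some(id) else None
          }
        }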
---
 .../clients/producer/internals/SenderTest.java     |  24 ++-
 .../transaction/ProducerIdManager.scala            | 167 +++++++++++--------
 .../transaction/TransactionCoordinator.scala       |  47 ++++--
 .../src/main/scala/kafka/server/BrokerServer.scala |   4 +-
 core/src/main/scala/kafka/server/KafkaServer.scala |   4 +-
 .../transaction/ProducerIdsIntegrationTest.scala   |  66 ++++++--
 .../transaction/ProducerIdManagerTest.scala        | 180 ++++++++++++++++-----
 .../TransactionCoordinatorConcurrencyTest.scala    |   7 +-
 .../transaction/TransactionCoordinatorTest.scala   |  11 +-
 .../AddPartitionsToTxnRequestServerTest.scala      |   9 +-
 .../kafka/server/common/ProducerIdsBlock.java      |  16 ++
 .../kafka/server/common/ProducerIdsBlockTest.java  |  25 +++
 12 files changed, 414 insertions(+), 146 deletions(-)

diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index f6c91659356..b80817465a5 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -3035,6 +3035,28 @@ public class SenderTest {
         verifyErrorMessage(produceResponse(tp0, 0L, Errors.INVALID_REQUEST, 0, -1, errorMessage), errorMessage);
     }
 
+    @Test
+    public void testSenderShouldRetryWithBackoffOnRetriableError() {
+        final long producerId = 343434L;
+        TransactionManager transactionManager = createTransactionManager();
+        setupWithTransactionState(transactionManager);
+        long start = time.milliseconds();
+
+        // first request is sent immediately
+        prepareAndReceiveInitProducerId(producerId, (short) -1, Errors.COORDINATOR_LOAD_IN_PROGRESS);
+        long request1 = time.milliseconds();
+        assertEquals(start, request1);
+
+        // backoff before sending second request
+        prepareAndReceiveInitProducerId(producerId, (short) -1, Errors.COORDINATOR_LOAD_IN_PROGRESS);
+        long request2 = time.milliseconds();
+        assertEquals(RETRY_BACKOFF_MS, request2 - request1);
+
+        // third request should also backoff
+        prepareAndReceiveInitProducerId(producerId, Errors.NONE);
+        assertEquals(RETRY_BACKOFF_MS, time.milliseconds() - request2);
+    }
+
     private void verifyErrorMessage(ProduceResponse response, String expectedMessage) throws Exception {
         Future<RecordMetadata> future = appendToAccumulator(tp0, 0L, "key", "value");
         sender.runOnce(); // connect
@@ -3191,7 +3213,7 @@ public class SenderTest {
     }
 
     private TransactionManager createTransactionManager() {
-        return new TransactionManager(new LogContext(), null, 0, 100L, new ApiVersions());
+        return new TransactionManager(new LogContext(), null, 0, RETRY_BACKOFF_MS, new ApiVersions());
     }
     
     private void setupWithTransactionState(TransactionManager transactionManager) {
diff --git a/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala b/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala
index f16785a7b6c..1e2b6ffac5a 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala
@@ -16,6 +16,7 @@
  */
 package kafka.coordinator.transaction
 
+import kafka.coordinator.transaction.ProducerIdManager.{IterationLimit, NoRetry, RetryBackoffMs}
 import kafka.server.{BrokerToControllerChannelManager, ControllerRequestCompletionHandler}
 import kafka.utils.Logging
 import kafka.zk.{KafkaZkClient, ProducerIdBlockZNode}
@@ -24,10 +25,11 @@ import org.apache.kafka.common.KafkaException
 import org.apache.kafka.common.message.AllocateProducerIdsRequestData
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{AllocateProducerIdsRequest, AllocateProducerIdsResponse}
+import org.apache.kafka.common.utils.Time
 import org.apache.kafka.server.common.ProducerIdsBlock
 
-import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
+import scala.compat.java8.OptionConverters.RichOptionalGeneric
 import scala.util.{Failure, Success, Try}
 
 /**
@@ -41,6 +43,9 @@ import scala.util.{Failure, Success, Try}
 object ProducerIdManager {
   // Once we reach this percentage of PIDs consumed from the current block, trigger a fetch of the next block
   val PidPrefetchThreshold = 0.90
+  val IterationLimit = 3
+  val RetryBackoffMs = 50
+  val NoRetry = -1L
 
   // Creates a ProducerIdGenerate that directly interfaces with ZooKeeper, IBP < 3.0-IV0
   def zk(brokerId: Int, zkClient: KafkaZkClient): ZkProducerIdManager = {
@@ -49,16 +54,20 @@ object ProducerIdManager {
 
   // Creates a ProducerIdGenerate that uses AllocateProducerIds RPC, IBP >= 3.0-IV0
   def rpc(brokerId: Int,
-            brokerEpochSupplier: () => Long,
-            controllerChannel: BrokerToControllerChannelManager,
-            maxWaitMs: Int): RPCProducerIdManager = {
-    new RPCProducerIdManager(brokerId, brokerEpochSupplier, controllerChannel, maxWaitMs)
+          time: Time,
+          brokerEpochSupplier: () => Long,
+          controllerChannel: BrokerToControllerChannelManager): RPCProducerIdManager = {
+
+    new RPCProducerIdManager(brokerId, time, brokerEpochSupplier, controllerChannel)
   }
 }
 
 trait ProducerIdManager {
-  def generateProducerId(): Long
+  def generateProducerId(): Try[Long]
   def shutdown() : Unit = {}
+
+  // For testing purposes
+  def hasValidBlock: Boolean
 }
 
 object ZkProducerIdManager {
@@ -103,8 +112,7 @@ object ZkProducerIdManager {
   }
 }
 
-class ZkProducerIdManager(brokerId: Int,
-                          zkClient: KafkaZkClient) extends ProducerIdManager with Logging {
+class ZkProducerIdManager(brokerId: Int, zkClient: KafkaZkClient) extends ProducerIdManager with Logging {
 
   this.logIdent = "[ZK ProducerId Manager " + brokerId + "]: "
 
@@ -123,73 +131,103 @@ class ZkProducerIdManager(brokerId: Int,
     }
   }
 
-  def generateProducerId(): Long = {
+  def generateProducerId(): Try[Long] = {
     this synchronized {
       // grab a new block of producerIds if this block has been exhausted
       if (nextProducerId > currentProducerIdBlock.lastProducerId) {
-        allocateNewProducerIdBlock()
+        try {
+          allocateNewProducerIdBlock()
+        } catch {
+          case t: Throwable =>
+            return Failure(t)
+        }
         nextProducerId = currentProducerIdBlock.firstProducerId
       }
       nextProducerId += 1
-      nextProducerId - 1
+      Success(nextProducerId - 1)
+    }
+  }
+
+  override def hasValidBlock: Boolean = {
+    this synchronized {
+      !currentProducerIdBlock.equals(ProducerIdsBlock.EMPTY)
     }
   }
 }
 
+/**
+ * RPCProducerIdManager allocates producer id blocks asynchronously and will immediately fail requests
+ * for producers to retry if it does not have an available producer id and is waiting on a new block.
+ */
 class RPCProducerIdManager(brokerId: Int,
+                           time: Time,
                            brokerEpochSupplier: () => Long,
-                           controllerChannel: BrokerToControllerChannelManager,
-                           maxWaitMs: Int) extends ProducerIdManager with Logging {
+                           controllerChannel: BrokerToControllerChannelManager) extends ProducerIdManager with Logging {
 
   this.logIdent = "[RPC ProducerId Manager " + brokerId + "]: "
 
-  private val nextProducerIdBlock = new ArrayBlockingQueue[Try[ProducerIdsBlock]](1)
+  // Visible for testing
+  private[transaction] var nextProducerIdBlock = new AtomicReference[ProducerIdsBlock](null)
+  private val currentProducerIdBlock: AtomicReference[ProducerIdsBlock] = new AtomicReference(ProducerIdsBlock.EMPTY)
   private val requestInFlight = new AtomicBoolean(false)
-  private var currentProducerIdBlock: ProducerIdsBlock = ProducerIdsBlock.EMPTY
-  private var nextProducerId: Long = -1L
+  private val backoffDeadlineMs = new AtomicLong(NoRetry)
 
-  override def generateProducerId(): Long = {
-    this synchronized {
-      if (nextProducerId == -1L) {
-        // Send an initial request to get the first block
-        maybeRequestNextBlock()
-        nextProducerId = 0L
-      } else {
-        nextProducerId += 1
-
-        // Check if we need to fetch the next block
-        if (nextProducerId >= (currentProducerIdBlock.firstProducerId + currentProducerIdBlock.size * ProducerIdManager.PidPrefetchThreshold)) {
-          maybeRequestNextBlock()
-        }
-      }
+  override def hasValidBlock: Boolean = {
+    nextProducerIdBlock.get != null
+  }
 
-      // If we've exhausted the current block, grab the next block (waiting if necessary)
-      if (nextProducerId > currentProducerIdBlock.lastProducerId) {
-        val block = nextProducerIdBlock.poll(maxWaitMs, TimeUnit.MILLISECONDS)
-        if (block == null) {
-          // Return COORDINATOR_LOAD_IN_PROGRESS rather than REQUEST_TIMED_OUT since older clients treat the error as fatal
-          // when it should be retriable like COORDINATOR_LOAD_IN_PROGRESS.
-          throw Errors.COORDINATOR_LOAD_IN_PROGRESS.exception("Timed out waiting for next producer ID block")
-        } else {
-          block match {
-            case Success(nextBlock) =>
-              currentProducerIdBlock = nextBlock
-              nextProducerId = currentProducerIdBlock.firstProducerId
-            case Failure(t) => throw t
+  override def generateProducerId(): Try[Long] = {
+    var result: Try[Long] = null
+    var iteration = 0
+    while (result == null) {
+      currentProducerIdBlock.get.claimNextId().asScala match {
+        case None =>
+          // Check the next block if current block is full
+          val block = nextProducerIdBlock.getAndSet(null)
+          if (block == null) {
+            // Return COORDINATOR_LOAD_IN_PROGRESS rather than REQUEST_TIMED_OUT since older clients treat the error as fatal
+            // when it should be retriable like COORDINATOR_LOAD_IN_PROGRESS.
+            maybeRequestNextBlock()
+            result = Failure(Errors.COORDINATOR_LOAD_IN_PROGRESS.exception("Producer ID block is full. Waiting for next block"))
+          } else {
+            currentProducerIdBlock.set(block)
+            requestInFlight.set(false)
+            iteration = iteration + 1
           }
-        }
+
+        case Some(nextProducerId) =>
+          // Check if we need to prefetch the next block
+          val prefetchTarget = currentProducerIdBlock.get.firstProducerId + (currentProducerIdBlock.get.size * ProducerIdManager.PidPrefetchThreshold).toLong
+          if (nextProducerId == prefetchTarget) {
+            maybeRequestNextBlock()
+          }
+          result = Success(nextProducerId)
+      }
+      if (iteration == IterationLimit) {
+        result = Failure(Errors.COORDINATOR_LOAD_IN_PROGRESS.exception("Producer ID block is full. Waiting for next block"))
       }
-      nextProducerId
     }
+    result
   }
 
 
-  private def maybeRequestNextBlock(): Unit = {
-    if (nextProducerIdBlock.isEmpty && requestInFlight.compareAndSet(false, true)) {
-      sendRequest()
+  // Visible for testing
+  private[transaction] def maybeRequestNextBlock(): Unit = {
+    val retryTimestamp = backoffDeadlineMs.get()
+    if (retryTimestamp == NoRetry || time.milliseconds() >= retryTimestamp) {
+      // Send a request only if we reached the retry deadline, or if no deadline was set.
+
+      if (nextProducerIdBlock.get == null &&
+        requestInFlight.compareAndSet(false, true) ) {
+
+        sendRequest()
+        // Reset backoff after a successful send.
+        backoffDeadlineMs.set(NoRetry)
+      }
     }
   }
 
+  // Visible for testing
   private[transaction] def sendRequest(): Unit = {
     val message = new AllocateProducerIdsRequestData()
       .setBrokerEpoch(brokerEpochSupplier.apply())
@@ -207,37 +245,40 @@ class RPCProducerIdManager(brokerId: Int,
     })
   }
 
+  // Visible for testing
   private[transaction] def handleAllocateProducerIdsResponse(response: AllocateProducerIdsResponse): Unit = {
-    requestInFlight.set(false)
     val data = response.data
+    var successfulResponse = false
     Errors.forCode(data.errorCode()) match {
       case Errors.NONE =>
         debug(s"Got next producer ID block from controller $data")
         // Do some sanity checks on the response
-        if (data.producerIdStart() < currentProducerIdBlock.lastProducerId) {
-          nextProducerIdBlock.put(Failure(new KafkaException(
-            s"Producer ID block is not monotonic with current block: 
current=$currentProducerIdBlock response=$data")))
+        if (data.producerIdStart() < currentProducerIdBlock.get.lastProducerId) {
+          error(s"Producer ID block is not monotonic with current block: 
current=$currentProducerIdBlock response=$data")
         } else if (data.producerIdStart() < 0 || data.producerIdLen() < 0 || data.producerIdStart() > Long.MaxValue - data.producerIdLen()) {
-          nextProducerIdBlock.put(Failure(new KafkaException(s"Producer ID 
block includes invalid ID range: $data")))
+          error(s"Producer ID block includes invalid ID range: $data")
         } else {
-          nextProducerIdBlock.put(
-            Success(new ProducerIdsBlock(brokerId, data.producerIdStart(), data.producerIdLen())))
+          nextProducerIdBlock.set(new ProducerIdsBlock(brokerId, data.producerIdStart(), data.producerIdLen()))
+          successfulResponse = true
         }
       case Errors.STALE_BROKER_EPOCH =>
-        warn("Our broker epoch was stale, trying again.")
-        maybeRequestNextBlock()
+        warn("Our broker currentBlockCount was stale, trying again.")
       case Errors.BROKER_ID_NOT_REGISTERED =>
         warn("Our broker ID is not yet known by the controller, trying again.")
-        maybeRequestNextBlock()
       case e: Errors =>
-        warn("Had an unknown error from the controller, giving up.")
-        nextProducerIdBlock.put(Failure(e.exception()))
+        error(s"Received an unexpected error code from the controller: $e")
+    }
+
+    if (!successfulResponse) {
+      // There is no need to compare and set because only one thread
+      // handles the AllocateProducerIds response.
+      backoffDeadlineMs.set(time.milliseconds() + RetryBackoffMs)
+      requestInFlight.set(false)
     }
   }
 
-  private[transaction] def handleTimeout(): Unit = {
+  private def handleTimeout(): Unit = {
     warn("Timed out when requesting AllocateProducerIds from the controller.")
     requestInFlight.set(false)
-    maybeRequestNextBlock()
   }
 }
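
For context, a hedged sketch of the backoff gate that maybeRequestNextBlock()
implements above (the constant names mirror the patch; the standalone object
and method names here are illustrative, not broker code):

    import java.util.concurrent.atomic.AtomicLong

    object BackoffGateSketch {
      val NoRetry = -1L        // sentinel: no backoff deadline armed
      val RetryBackoffMs = 50L // mirrors ProducerIdManager.RetryBackoffMs

      private val backoffDeadlineMs = new AtomicLong(NoRetry)

      // A new AllocateProducerIds request may go out only when no deadline
      // is armed or the deadline has already passed.
      def maySendRequest(nowMs: Long): Boolean = {
        val deadline = backoffDeadlineMs.get()
        deadline == NoRetry || nowMs >= deadline
      }

      // On a failed response, arm the deadline; only the single response
      // handler thread writes it, so a plain set is sufficient.
      def onFailedResponse(nowMs: Long): Unit =
        backoffDeadlineMs.set(nowMs + RetryBackoffMs)
    }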
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index bb1b3792c83..7eda8f3b1f2 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -32,6 +32,7 @@ import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time}
 import org.apache.kafka.server.util.Scheduler
 
 import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success}
 
 object TransactionCoordinator {
 
@@ -113,8 +114,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
     if (transactionalId == null) {
       // if the transactional id is null, then always blindly accept the request
       // and return a new producerId from the producerId manager
-      val producerId = producerIdManager.generateProducerId()
-      responseCallback(InitProducerIdResult(producerId, producerEpoch = 0, Errors.NONE))
+      producerIdManager.generateProducerId() match {
+        case Success(producerId) =>
+          responseCallback(InitProducerIdResult(producerId, producerEpoch = 0, Errors.NONE))
+        case Failure(exception) =>
+          responseCallback(initTransactionError(Errors.forException(exception)))
+      }
     } else if (transactionalId.isEmpty) {
       // if transactional id is empty then return error as invalid request. This is
       // to make TransactionCoordinator's behavior consistent with producer client
@@ -125,17 +130,22 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
     } else {
       val coordinatorEpochAndMetadata = txnManager.getTransactionState(transactionalId).flatMap {
         case None =>
-          val producerId = producerIdManager.generateProducerId()
-          val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
-            producerId = producerId,
-            lastProducerId = RecordBatch.NO_PRODUCER_ID,
-            producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
-            lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
-            txnTimeoutMs = transactionTimeoutMs,
-            state = Empty,
-            topicPartitions = collection.mutable.Set.empty[TopicPartition],
-            txnLastUpdateTimestamp = time.milliseconds())
-          txnManager.putTransactionStateIfNotExists(createdMetadata)
+          producerIdManager.generateProducerId() match {
+            case Success(producerId) =>
+              val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
+                producerId = producerId,
+                lastProducerId = RecordBatch.NO_PRODUCER_ID,
+                producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
+                lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
+                txnTimeoutMs = transactionTimeoutMs,
+                state = Empty,
+                topicPartitions = collection.mutable.Set.empty[TopicPartition],
+                txnLastUpdateTimestamp = time.milliseconds())
+              txnManager.putTransactionStateIfNotExists(createdMetadata)
+
+            case Failure(exception) =>
+              Left(Errors.forException(exception))
+          }
 
         case Some(epochAndTxnMetadata) => Right(epochAndTxnMetadata)
       }
@@ -231,9 +241,14 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
             // If the epoch is exhausted and the expected epoch (if provided) matches it, generate a new producer ID
             if (txnMetadata.isProducerEpochExhausted &&
                 expectedProducerIdAndEpoch.forall(_.epoch == txnMetadata.producerEpoch)) {
-              val newProducerId = producerIdManager.generateProducerId()
-              Right(txnMetadata.prepareProducerIdRotation(newProducerId, transactionTimeoutMs, time.milliseconds(),
-                expectedProducerIdAndEpoch.isDefined))
+
+              producerIdManager.generateProducerId() match {
+                case Success(producerId) =>
+                  Right(txnMetadata.prepareProducerIdRotation(producerId, transactionTimeoutMs, time.milliseconds(),
+                    expectedProducerIdAndEpoch.isDefined))
+                case Failure(exception) =>
+                  Left(Errors.forException(exception))
+              }
             } else {
               txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, expectedProducerIdAndEpoch.map(_.epoch),
                 time.milliseconds())
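
For context, the error translation the coordinator relies on above: a Failure
from generateProducerId() carries a retriable exception, and the existing
Errors.forException mapping turns it into the wire error the client sees (a
small sketch; the exception instance here is constructed by hand purely for
illustration):

    import org.apache.kafka.common.errors.CoordinatorLoadInProgressException
    import org.apache.kafka.common.protocol.Errors

    val failure = new CoordinatorLoadInProgressException(
      "Producer ID block is full. Waiting for next block")
    assert(Errors.forException(failure) == Errors.COORDINATOR_LOAD_IN_PROGRESS)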
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala
index 2bf29c32d97..edb645f561c 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -295,9 +295,9 @@ class BrokerServer(
 
       val producerIdManagerSupplier = () => ProducerIdManager.rpc(
         config.brokerId,
+        time,
         brokerEpochSupplier = () => lifecycleManager.brokerEpoch,
-        clientToControllerChannelManager,
-        config.requestTimeoutMs
+        clientToControllerChannelManager
       )
 
       // Create transaction coordinator, but don't start it until we've started replica manager.
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index 10acd74241c..28c07840973 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -461,9 +461,9 @@ class KafkaServer(
         val producerIdManager = if (config.interBrokerProtocolVersion.isAllocateProducerIdsSupported) {
           ProducerIdManager.rpc(
             config.brokerId,
+            time,
             brokerEpochSupplier = brokerEpochSupplier,
-            clientToControllerChannelManager,
-            config.requestTimeoutMs
+            clientToControllerChannelManager
           )
         } else {
           ProducerIdManager.zk(config.brokerId, zkClient)
diff --git a/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala b/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala
index 558f0041e0a..519c2bcf088 100644
--- a/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala
@@ -24,14 +24,16 @@ import kafka.test.junit.ClusterTestExtensions
 import kafka.test.{ClusterConfig, ClusterInstance}
 import org.apache.kafka.common.message.InitProducerIdRequestData
 import org.apache.kafka.common.network.ListenerName
+import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.RecordBatch
 import org.apache.kafka.common.requests.{InitProducerIdRequest, InitProducerIdResponse}
 import org.apache.kafka.server.common.MetadataVersion
-import org.junit.jupiter.api.Assertions.assertEquals
-import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
+import org.junit.jupiter.api.{BeforeEach, Disabled, Timeout}
 import org.junit.jupiter.api.extension.ExtendWith
 
 import java.util.stream.{Collectors, IntStream}
+import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 
 @ExtendWith(value = Array(classOf[ClusterTestExtensions]))
@@ -61,27 +63,59 @@ class ProducerIdsIntegrationTest {
     clusterInstance.stop()
   }
 
+  @ClusterTest(clusterType = Type.ZK, brokers = 1, autoStart = AutoStart.NO)
+  @Timeout(20)
+  def testHandleAllocateProducerIdsSingleRequestHandlerThread(clusterInstance: ClusterInstance): Unit = {
+    clusterInstance.config().serverProperties().put(KafkaConfig.NumIoThreadsProp, "1")
+    clusterInstance.start()
+    verifyUniqueIds(clusterInstance)
+    clusterInstance.stop()
+  }
+
+  @Disabled // TODO: Enable once producer id block size is configurable (KAFKA-15029)
+  @ClusterTest(clusterType = Type.ZK, brokers = 1, autoStart = AutoStart.NO)
+  def testMultipleAllocateProducerIdsRequest(clusterInstance: ClusterInstance): Unit = {
+    clusterInstance.config().serverProperties().put(KafkaConfig.NumIoThreadsProp, "2")
+    clusterInstance.start()
+    verifyUniqueIds(clusterInstance)
+    clusterInstance.stop()
+  }
+
   private def verifyUniqueIds(clusterInstance: ClusterInstance): Unit = {
-    // Request enough PIDs from each broker to ensure each broker generates two PID blocks
+    // Request enough PIDs from each broker to ensure each broker generates two blocks
     val ids = clusterInstance.brokerSocketServers().stream().flatMap( broker => {
-      IntStream.range(0, 1001).parallel().mapToObj( _ => nextProducerId(broker, clusterInstance.clientListener()))
-    }).collect(Collectors.toList[Long]).asScala.toSeq
+      IntStream.range(0, 1001).parallel().mapToObj( _ =>
+        nextProducerId(broker, clusterInstance.clientListener())
+      )}).collect(Collectors.toList[Long]).asScala.toSeq
 
-    assertEquals(3003, ids.size, "Expected exactly 3003 IDs")
-    assertEquals(ids.size, ids.distinct.size, "Found duplicate producer IDs")
+    val brokerCount = clusterInstance.brokerIds.size
+    val expectedTotalCount = 1001 * brokerCount
+    assertEquals(expectedTotalCount, ids.size, s"Expected exactly $expectedTotalCount IDs")
+    assertEquals(expectedTotalCount, ids.distinct.size, "Found duplicate producer IDs")
   }
 
   private def nextProducerId(broker: SocketServer, listener: ListenerName): Long = {
-    val data = new InitProducerIdRequestData()
-      .setProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH)
-      .setProducerId(RecordBatch.NO_PRODUCER_ID)
-      .setTransactionalId(null)
-      .setTransactionTimeoutMs(10)
-    val request = new InitProducerIdRequest.Builder(data).build()
+    // Generating producer ids may fail while waiting for the initial block and also
+    // when the current block is full and waiting for the prefetched block.
+    val deadline = 5.seconds.fromNow
+    var shouldRetry = true
+    var response: InitProducerIdResponse = null
+    while(shouldRetry && deadline.hasTimeLeft()) {
+      val data = new InitProducerIdRequestData()
+        .setProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH)
+        .setProducerId(RecordBatch.NO_PRODUCER_ID)
+        .setTransactionalId(null)
+        .setTransactionTimeoutMs(10)
+      val request = new InitProducerIdRequest.Builder(data).build()
+
+      response = IntegrationTestUtils.connectAndReceive[InitProducerIdResponse](request,
+        destination = broker,
+        listenerName = listener)
 
-    val response = IntegrationTestUtils.connectAndReceive[InitProducerIdResponse](request,
-      destination = broker,
-      listenerName = listener)
+      shouldRetry = response.data.errorCode == Errors.COORDINATOR_LOAD_IN_PROGRESS.code
+    }
+    assertTrue(deadline.hasTimeLeft())
+    assertEquals(Errors.NONE.code, response.data.errorCode)
     response.data().producerId()
   }
 }
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala
index 666a3c363ff..73b208196e6 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala
@@ -16,13 +16,16 @@
  */
 package kafka.coordinator.transaction
 
+import kafka.coordinator.transaction.ProducerIdManager.RetryBackoffMs
 import kafka.server.BrokerToControllerChannelManager
+import kafka.utils.TestUtils
 import kafka.zk.{KafkaZkClient, ProducerIdBlockZNode}
 import org.apache.kafka.common.KafkaException
 import org.apache.kafka.common.errors.CoordinatorLoadInProgressException
 import org.apache.kafka.common.message.AllocateProducerIdsResponseData
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.AllocateProducerIdsResponse
+import org.apache.kafka.common.utils.{MockTime, Time}
 import org.apache.kafka.server.common.ProducerIdsBlock
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
@@ -31,7 +34,11 @@ import org.junit.jupiter.params.provider.{EnumSource, ValueSource}
 import org.mockito.ArgumentCaptor
 import org.mockito.ArgumentMatchers.{any, anyString}
 import org.mockito.Mockito.{mock, when}
-import java.util.stream.IntStream
+
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.mutable
+import scala.util.{Failure, Success}
 
 class ProducerIdManagerTest {
 
@@ -39,20 +46,48 @@ class ProducerIdManagerTest {
   val zkClient: KafkaZkClient = mock(classOf[KafkaZkClient])
 
   // Mutable test implementation that lets us easily set the idStart and error
-  class MockProducerIdManager(val brokerId: Int, var idStart: Long, val idLen: Int, var error: Errors = Errors.NONE, timeout: Boolean = false)
-    extends RPCProducerIdManager(brokerId, () => 1, brokerToController, 100) {
+  class MockProducerIdManager(
+    val brokerId: Int,
+    var idStart: Long,
+    val idLen: Int,
+    var error: Errors = Errors.NONE,
+    val isErroneousBlock: Boolean = false,
+    val time: Time = Time.SYSTEM,
+    var remainingRetries: Int = 1
+  ) extends RPCProducerIdManager(brokerId, time, () => 1, brokerToController) {
+
+    private val brokerToControllerRequestExecutor = Executors.newSingleThreadExecutor()
+    val capturedFailure: AtomicBoolean = new AtomicBoolean(false)
 
     override private[transaction] def sendRequest(): Unit = {
-      if (timeout)
-        return
 
-      if (error == Errors.NONE) {
-        handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse(
-          new AllocateProducerIdsResponseData().setProducerIdStart(idStart).setProducerIdLen(idLen)))
-        idStart += idLen
+      brokerToControllerRequestExecutor.submit(() => {
+        if (error == Errors.NONE) {
+          handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse(
+            new AllocateProducerIdsResponseData().setProducerIdStart(idStart).setProducerIdLen(idLen)))
+          if (!isErroneousBlock) {
+            idStart += idLen
+          }
+        } else {
+          handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse(
+            new AllocateProducerIdsResponseData().setErrorCode(error.code)))
+        }
+      }, 0)
+    }
+
+    override private[transaction] def handleAllocateProducerIdsResponse(response: AllocateProducerIdsResponse): Unit = {
+      super.handleAllocateProducerIdsResponse(response)
+      capturedFailure.set(nextProducerIdBlock.get == null)
+    }
+
+    override private[transaction] def maybeRequestNextBlock(): Unit = {
+      if (error == Errors.NONE && !isErroneousBlock) {
+        super.maybeRequestNextBlock()
       } else {
-        handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse(
-          new AllocateProducerIdsResponseData().setErrorCode(error.code)))
+        if (remainingRetries > 0) {
+          super.maybeRequestNextBlock()
+          remainingRetries -= 1
+        }
       }
     }
   }
@@ -80,26 +115,20 @@ class ProducerIdManagerTest {
     val manager1 = new ZkProducerIdManager(0, zkClient)
     val manager2 = new ZkProducerIdManager(1, zkClient)
 
-    val pid1 = manager1.generateProducerId()
-    val pid2 = manager2.generateProducerId()
+    val pid1 = manager1.generateProducerId().get
+    val pid2 = manager2.generateProducerId().get
 
     assertEquals(0, pid1)
     assertEquals(ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE, pid2)
 
     for (i <- 1L until ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE)
-      assertEquals(pid1 + i, manager1.generateProducerId())
+      assertEquals(pid1 + i, manager1.generateProducerId().get)
 
     for (i <- 1L until ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE)
-      assertEquals(pid2 + i, manager2.generateProducerId())
-
-    assertEquals(pid2 + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE, manager1.generateProducerId())
-    assertEquals(pid2 + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE * 2, manager2.generateProducerId())
-  }
+      assertEquals(pid2 + i, manager2.generateProducerId().get)
 
-  @Test
-  def testRPCProducerIdManagerThrowsConcurrentTransactions(): Unit = {
-    val manager1 = new MockProducerIdManager(0, 0, 0, timeout = true)
-    assertThrows(classOf[CoordinatorLoadInProgressException], () => manager1.generateProducerId())
+    assertEquals(pid2 + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE, manager1.generateProducerId().get)
+    assertEquals(pid2 + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE * 2, manager2.generateProducerId().get)
   }
 
   @Test
@@ -113,38 +142,113 @@ class ProducerIdManagerTest {
   }
 
   @ParameterizedTest
-  @ValueSource(ints = Array(1, 2, 10))
-  def testContiguousIds(idBlockLen: Int): Unit = {
+  @ValueSource(ints = Array(1, 2, 10, 100))
+  def testConcurrentGeneratePidRequests(idBlockLen: Int): Unit = {
+    // Send concurrent generateProducerId requests. Ensure that the generated producer id is unique.
+    // For each block (total 3 blocks), only "idBlockLen" number of requests should go through.
+    // All other requests should fail immediately.
+
+    val numThreads = 5
+    val latch = new CountDownLatch(idBlockLen * 3)
     val manager = new MockProducerIdManager(0, 0, idBlockLen)
-
-    IntStream.range(0, idBlockLen * 3).forEach { i =>
-      assertEquals(i, manager.generateProducerId())
+    val pidMap = mutable.Map[Long, Int]()
+    val requestHandlerThreadPool = Executors.newFixedThreadPool(numThreads)
+
+    for ( _ <- 0 until numThreads) {
+      requestHandlerThreadPool.submit(() => {
+        while(latch.getCount > 0) {
+          val result = manager.generateProducerId()
+          result match {
+            case Success(pid) =>
+              pidMap synchronized {
+                if (latch.getCount != 0) {
+                  val counter = pidMap.getOrElse(pid, 0)
+                  pidMap += pid -> (counter + 1)
+                  latch.countDown()
+                }
+              }
+
+            case Failure(exception) =>
+              assertEquals(classOf[CoordinatorLoadInProgressException], exception.getClass)
+          }
+          Thread.sleep(100)
+        }
+      }, 0)
+    }
+    assertTrue(latch.await(12000, TimeUnit.MILLISECONDS))
+    requestHandlerThreadPool.shutdown()
+
+    assertEquals(idBlockLen * 3, pidMap.size)
+    pidMap.foreach { case (pid, count) =>
+      assertEquals(1, count)
+      assertTrue(pid < (3 * idBlockLen) + numThreads, s"Unexpected pid $pid; " +
+        s"non-contiguous blocks generated or did not fully exhaust blocks.")
     }
   }
 
   @ParameterizedTest
   @EnumSource(value = classOf[Errors], names = Array("UNKNOWN_SERVER_ERROR", "INVALID_REQUEST"))
   def testUnrecoverableErrors(error: Errors): Unit = {
-    val manager = new MockProducerIdManager(0, 0, 1)
-    assertEquals(0, manager.generateProducerId())
+    val time = new MockTime()
+    val manager = new MockProducerIdManager(0, 0, 1, time = time)
+
+    verifyNewBlockAndProducerId(manager, new ProducerIdsBlock(0, 0, 1), 0)
 
     manager.error = error
-    assertThrows(classOf[Throwable], () => manager.generateProducerId())
+    verifyFailure(manager)
 
     manager.error = Errors.NONE
-    assertEquals(1, manager.generateProducerId())
+    time.sleep(RetryBackoffMs)
+    verifyNewBlockAndProducerId(manager, new ProducerIdsBlock(0, 1, 1), 1)
   }
 
   @Test
   def testInvalidRanges(): Unit = {
-    var manager = new MockProducerIdManager(0, -1, 10)
-    assertThrows(classOf[KafkaException], () => manager.generateProducerId())
+    var manager = new MockProducerIdManager(0, -1, 10, isErroneousBlock = true)
+    verifyFailure(manager)
+
+    manager = new MockProducerIdManager(0, 0, -1, isErroneousBlock = true)
+    verifyFailure(manager)
+
+    manager = new MockProducerIdManager(0, Long.MaxValue-1, 10, isErroneousBlock = true)
+    verifyFailure(manager)
+  }
+
+  @Test
+  def testRetryBackoff(): Unit = {
+    val time = new MockTime()
+    val manager = new MockProducerIdManager(0, 0, 1,
+      error = Errors.UNKNOWN_SERVER_ERROR, time = time, remainingRetries = 2)
+
+    verifyFailure(manager)
+    manager.error = Errors.NONE
+
+    // We should only get a new block once retry backoff ms has passed.
+    assertEquals(classOf[CoordinatorLoadInProgressException], manager.generateProducerId().failed.get.getClass)
+    time.sleep(RetryBackoffMs)
+    verifyNewBlockAndProducerId(manager, new ProducerIdsBlock(0, 0, 1), 0)
+  }
+
+  private def verifyFailure(manager: MockProducerIdManager): Unit = {
+    assertEquals(classOf[CoordinatorLoadInProgressException], manager.generateProducerId().failed.get.getClass)
+    TestUtils.waitUntilTrue(() => {
+      manager synchronized {
+        manager.capturedFailure.get
+      }
+    }, "Expected failure")
+    manager.capturedFailure.set(false)
+  }
 
-    manager = new MockProducerIdManager(0, 0, -1)
-    assertThrows(classOf[KafkaException], () => manager.generateProducerId())
+  private def verifyNewBlockAndProducerId(manager: MockProducerIdManager,
+                                          expectedBlock: ProducerIdsBlock,
+                                          expectedPid: Long): Unit = {
 
-    manager = new MockProducerIdManager(0, Long.MaxValue-1, 10)
-    assertThrows(classOf[KafkaException], () => manager.generateProducerId())
+    assertEquals(classOf[CoordinatorLoadInProgressException], manager.generateProducerId().failed.get.getClass)
+    TestUtils.waitUntilTrue(() => {
+      val nextBlock = manager.nextProducerIdBlock.get
+      nextBlock != null && nextBlock.equals(expectedBlock)
+    }, "failed to generate block")
+    assertEquals(expectedPid, manager.generateProducerId().get)
   }
 }
 
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 9a0d8143766..d57a8e974c6 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -43,6 +43,7 @@ import org.mockito.Mockito.{mock, when}
 
 import scala.jdk.CollectionConverters._
 import scala.collection.{Map, mutable}
+import scala.util.Success
 
 class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[Transaction] {
   private val nTransactions = nThreads * 10
@@ -82,7 +83,11 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
 
     val pidGenerator: ProducerIdManager = mock(classOf[ProducerIdManager])
     when(pidGenerator.generateProducerId())
-      .thenAnswer(_ => if (bumpProducerId) producerId + 1 else producerId)
+      .thenAnswer(_ => if (bumpProducerId) {
+        Success(producerId + 1)
+      } else {
+        Success(producerId)
+      })
     val brokerNode = new Node(0, "host", 10)
     val metadataCache: MetadataCache = mock(classOf[MetadataCache])
     when(metadataCache.getPartitionLeaderEndpoint(
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index ab8e1052f93..a5e2d57d87f 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -31,6 +31,7 @@ import org.mockito.Mockito.{mock, times, verify, when}
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
+import scala.util.Success
 
 class TransactionCoordinatorTest {
 
@@ -46,7 +47,7 @@ class TransactionCoordinatorTest {
   val brokerId = 0
   val coordinatorEpoch = 0
   private val transactionalId = "known"
-  private val producerId = 10
+  private val producerId = 10L
   private val producerEpoch: Short = 1
   private val txnTimeoutMs = 1
 
@@ -68,7 +69,7 @@ class TransactionCoordinatorTest {
   private def mockPidGenerator(): Unit = {
     when(pidGenerator.generateProducerId()).thenAnswer(_ => {
       nextPid += 1
-      nextPid - 1
+      Success(nextPid - 1)
     })
   }
 
@@ -908,7 +909,7 @@ class TransactionCoordinatorTest {
       (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
 
     when(pidGenerator.generateProducerId())
-      .thenReturn(producerId + 1)
+      .thenReturn(Success(producerId + 1))
 
     when(transactionManager.validateTransactionTimeoutMs(anyInt()))
       .thenReturn(true)
@@ -949,7 +950,7 @@ class TransactionCoordinatorTest {
       (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
 
     when(pidGenerator.generateProducerId())
-      .thenReturn(producerId + 1)
+      .thenReturn(Success(producerId + 1))
 
     when(transactionManager.validateTransactionTimeoutMs(anyInt()))
       .thenReturn(true)
@@ -1208,7 +1209,7 @@ class TransactionCoordinatorTest {
 
   private def validateIncrementEpochAndUpdateMetadata(state: TransactionState): Unit = {
     when(pidGenerator.generateProducerId())
-      .thenReturn(producerId)
+      .thenReturn(Success(producerId))
 
     when(transactionManager.validateTransactionTimeoutMs(anyInt()))
       .thenReturn(true)
diff --git a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
index e59ed821c21..6e296c2892b 100644
--- a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
@@ -174,8 +174,13 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest {
     val findCoordinatorResponse = connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId))
     val coordinatorId = findCoordinatorResponse.data().coordinators().get(0).nodeId()
 
-    val initPidRequest = new InitProducerIdRequest.Builder(new InitProducerIdRequestData().setTransactionalId(transactionalId).setTransactionTimeoutMs(10000)).build()
-    val initPidResponse = connectAndReceive[InitProducerIdResponse](initPidRequest, brokerSocketServer(coordinatorId))
+    var initPidResponse: InitProducerIdResponse = null
+
+    TestUtils.waitUntilTrue(() => {
+      val initPidRequest = new InitProducerIdRequest.Builder(new InitProducerIdRequestData().setTransactionalId(transactionalId).setTransactionTimeoutMs(10000)).build()
+      initPidResponse = connectAndReceive[InitProducerIdResponse](initPidRequest, brokerSocketServer(coordinatorId))
+      initPidResponse.error() != Errors.COORDINATOR_LOAD_IN_PROGRESS
+    }, "Failed to get a valid InitProducerIdResponse.")
 
     val producerId1 = initPidResponse.data().producerId()
     val producerEpoch1 = initPidResponse.data().producerEpoch()
diff --git a/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java b/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java
index b2633bf7034..c4240018f9b 100644
--- a/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java
+++ b/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java
@@ -18,6 +18,8 @@
 package org.apache.kafka.server.common;
 
 import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicLong;
 
 /**
  * Holds a range of Producer IDs used for Transactional and EOS producers.
@@ -32,11 +34,25 @@ public class ProducerIdsBlock {
     private final int assignedBrokerId;
     private final long firstProducerId;
     private final int blockSize;
+    private final AtomicLong producerIdCounter;
 
     public ProducerIdsBlock(int assignedBrokerId, long firstProducerId, int blockSize) {
         this.assignedBrokerId = assignedBrokerId;
         this.firstProducerId = firstProducerId;
         this.blockSize = blockSize;
+        producerIdCounter = new AtomicLong(firstProducerId);
+    }
+
+    /**
+     * Claim the next available producer id from the block.
+     * Returns an empty result if there are no more available producer ids in the block.
+     */
+    public Optional<Long> claimNextId() {
+        long nextId = producerIdCounter.getAndIncrement();
+        if (nextId > lastProducerId()) {
+            return Optional.empty();
+        }
+        return Optional.of(nextId);
     }
 
     /**
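
For context, a hedged usage sketch of the new claimNextId() API as consumed
from Scala in this patch (the RichOptionalGeneric conversion comes from
scala-java8-compat, matching the import added to ProducerIdManager.scala; the
constructor arguments below are arbitrary example values):

    import org.apache.kafka.server.common.ProducerIdsBlock
    import scala.compat.java8.OptionConverters.RichOptionalGeneric

    val block = new ProducerIdsBlock(0, 100L, 2) // brokerId, first id, block size
    assert(block.claimNextId().asScala.contains(100L))
    assert(block.claimNextId().asScala.contains(101L))
    // Exhausted: the caller must swap in the next block or fail fast.
    assert(block.claimNextId().asScala.isEmpty)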
diff --git a/server-common/src/test/java/org/apache/kafka/server/common/ProducerIdsBlockTest.java b/server-common/src/test/java/org/apache/kafka/server/common/ProducerIdsBlockTest.java
index f15c171a2c0..ea7973d7264 100644
--- a/server-common/src/test/java/org/apache/kafka/server/common/ProducerIdsBlockTest.java
+++ b/server-common/src/test/java/org/apache/kafka/server/common/ProducerIdsBlockTest.java
@@ -18,7 +18,14 @@ package org.apache.kafka.server.common;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 class ProducerIdsBlockTest {
 
@@ -43,4 +50,22 @@ class ProducerIdsBlockTest {
         assertEquals(brokerId, block.assignedBrokerId());
     }
 
+    @Test
+    public void testClaimNextId() throws Exception {
+        for (int i = 0; i < 50; i++) {
+            ProducerIdsBlock block = new ProducerIdsBlock(0, 1, 1);
+            CountDownLatch latch = new CountDownLatch(1);
+            AtomicLong counter = new AtomicLong(0);
+            CompletableFuture.runAsync(() -> {
+                Optional<Long> pid = block.claimNextId();
+                counter.addAndGet(pid.orElse(0L));
+                latch.countDown();
+            });
+            Optional<Long> pid = block.claimNextId();
+            counter.addAndGet(pid.orElse(0L));
+            assertTrue(latch.await(1, TimeUnit.SECONDS));
+            assertEquals(1, counter.get());
+        }
+    }
+
 }
\ No newline at end of file

