This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new f38b0d886c0 KAFKA-15661: KIP-951: Server side changes (#14444)
f38b0d886c0 is described below

commit f38b0d886c01bf4e83a08ca7c8a4da01c2d0c686
Author: Crispin Bernier <[email protected]>
AuthorDate: Fri Nov 10 00:07:21 2023 -0500

    KAFKA-15661: KIP-951: Server side changes (#14444)
    
    These are the server-side changes to populate the fields in KIP-951. On 
NOT_LEADER_OR_FOLLOWER errors in both FETCH and PRODUCE (and on 
FENCED_LEADER_EPOCH errors in FETCH), the new leader ID and epoch are 
retrieved from the local cache through ReplicaManager and included in the 
response, falling back to the metadata cache if they are unavailable there. 
The endpoint for the new leader is retrieved from the metadata cache. The new 
fields are all optional (tagged) and an IBP bump was required.
    
    
https://cwiki.apache.org/confluence/display/KAFKA/KIP-951%3A+Leader+discovery+optimisations+for+the+client
    
    https://issues.apache.org/jira/browse/KAFKA-15661
    
    Protocol changes: #14627
    
    Testing:
    Benchmarking is described here: 
https://cwiki.apache.org/confluence/display/KAFKA/KIP-951%3A+Leader+discovery+optimisations+for+the+client#KIP951:Leaderdiscoveryoptimisationsfortheclient-BenchmarkResults
    Unit tests: ./gradlew core:test --tests kafka.server.KafkaApisTest
    
    Reviewers: Justine Olshan <[email protected]>, David Jacot 
<[email protected]>, Jason Gustafson <[email protected]>, Fred Zheng 
<[email protected]>, Mayank Shekhar Narula <[email protected]>,  
Yang Yang <[email protected]>, David Mao <[email protected]>, Kirk True 
<[email protected]>
---
 .../kafka/common/requests/ProduceResponse.java     |  26 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |  53 +++-
 .../scala/unit/kafka/server/KafkaApisTest.scala    | 270 ++++++++++++++++++++-
 3 files changed, 342 insertions(+), 7 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java 
b/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java
index a8c2a801e4c..186ad9b80a1 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.common.requests;
 
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.message.ProduceResponseData;
 import org.apache.kafka.common.message.ProduceResponseData.LeaderIdAndEpoch;
@@ -72,7 +73,7 @@ public class ProduceResponse extends AbstractResponse {
      */
     @Deprecated
     public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) {
-        this(responses, DEFAULT_THROTTLE_TIME);
+        this(responses, DEFAULT_THROTTLE_TIME, Collections.emptyList());
     }
 
     /**
@@ -83,10 +84,23 @@ public class ProduceResponse extends AbstractResponse {
      */
     @Deprecated
     public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, 
int throttleTimeMs) {
-        this(toData(responses, throttleTimeMs));
+        this(toData(responses, throttleTimeMs, Collections.emptyList()));
     }
 
-    private static ProduceResponseData toData(Map<TopicPartition, 
PartitionResponse> responses, int throttleTimeMs) {
+    /**
+     * Constructor for the latest version
+     * This is deprecated in favor of using the ProduceResponseData 
constructor, KafkaApis should switch to that
+     * in KAFKA-10730
+     * @param responses Produced data grouped by topic-partition
+     * @param throttleTimeMs Time in milliseconds the response was throttled
+     * @param nodeEndpoints List of node endpoints
+     */
+    @Deprecated
+    public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, 
int throttleTimeMs, List<Node> nodeEndpoints) {
+        this(toData(responses, throttleTimeMs, nodeEndpoints));
+    }
+
+    private static ProduceResponseData toData(Map<TopicPartition, 
PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
         ProduceResponseData data = new 
ProduceResponseData().setThrottleTimeMs(throttleTimeMs);
         responses.forEach((tp, response) -> {
             ProduceResponseData.TopicProduceResponse tpr = 
data.responses().find(tp.topic());
@@ -110,6 +124,12 @@ public class ProduceResponse extends AbstractResponse {
                             .setBatchIndexErrorMessage(e.message))
                         .collect(Collectors.toList())));
         });
+        nodeEndpoints.forEach(endpoint -> data.nodeEndpoints()
+                .add(new ProduceResponseData.NodeEndpoint()
+                        .setNodeId(endpoint.id())
+                        .setHost(endpoint.host())
+                        .setPort(endpoint.port())
+                        .setRack(endpoint.rack())));
         return data;
     }
 
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index bd0959d40d2..3f5a435d147 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -562,6 +562,23 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
   }
 
+  case class LeaderNode(leaderId: Int, leaderEpoch: Int, node: Option[Node])
+
+  private def getCurrentLeader(tp: TopicPartition, ln: ListenerName): 
LeaderNode = {
+    val partitionInfoOrError = replicaManager.getPartitionOrError(tp)
+    val (leaderId, leaderEpoch) = partitionInfoOrError match {
+      case Right(x) =>
+        (x.leaderReplicaIdOpt.getOrElse(-1), x.getLeaderEpoch)
+      case Left(x) =>
+        debug(s"Unable to retrieve local leaderId and Epoch with error $x, 
falling back to metadata cache")
+        metadataCache.getPartitionInfo(tp.topic, tp.partition) match {
+          case Some(pinfo) => (pinfo.leader(), pinfo.leaderEpoch())
+          case None => (-1, -1)
+        }
+    }
+    LeaderNode(leaderId, leaderEpoch, 
metadataCache.getAliveBrokerNode(leaderId, ln))
+  }
+
   /**
    * Handle a produce request
    */
@@ -614,6 +631,7 @@ class KafkaApis(val requestChannel: RequestChannel,
       val mergedResponseStatus = responseStatus ++ unauthorizedTopicResponses 
++ nonExistingTopicResponses ++ invalidRequestResponses
       var errorInResponse = false
 
+      val nodeEndpoints = new mutable.HashMap[Int, Node]
       mergedResponseStatus.forKeyValue { (topicPartition, status) =>
         if (status.error != Errors.NONE) {
           errorInResponse = true
@@ -622,6 +640,20 @@ class KafkaApis(val requestChannel: RequestChannel,
             request.header.clientId,
             topicPartition,
             status.error.exceptionName))
+
+          if (request.header.apiVersion >= 10) {
+            status.error match {
+              case Errors.NOT_LEADER_OR_FOLLOWER =>
+                val leaderNode = getCurrentLeader(topicPartition, 
request.context.listenerName)
+                leaderNode.node.foreach { node =>
+                  nodeEndpoints.put(node.id(), node)
+                }
+                status.currentLeader
+                  .setLeaderId(leaderNode.leaderId)
+                  .setLeaderEpoch(leaderNode.leaderEpoch)
+                case _ =>
+            }
+          }
         }
       }
 
@@ -665,7 +697,7 @@ class KafkaApis(val requestChannel: RequestChannel,
           requestHelper.sendNoOpResponseExemptThrottle(request)
         }
       } else {
-        requestChannel.sendResponse(request, new 
ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs), None)
+        requestChannel.sendResponse(request, new 
ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs, 
nodeEndpoints.values.toList.asJava), None)
       }
     }
 
@@ -843,6 +875,7 @@ class KafkaApis(val requestChannel: RequestChannel,
               .setRecords(unconvertedRecords)
               .setPreferredReadReplica(partitionData.preferredReadReplica)
               .setDivergingEpoch(partitionData.divergingEpoch)
+              .setCurrentLeader(partitionData.currentLeader())
         }
       }
     }
@@ -851,6 +884,7 @@ class KafkaApis(val requestChannel: RequestChannel,
     def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, 
FetchPartitionData)]): Unit = {
       val partitions = new util.LinkedHashMap[TopicIdPartition, 
FetchResponseData.PartitionData]
       val reassigningPartitions = mutable.Set[TopicIdPartition]()
+      val nodeEndpoints = new mutable.HashMap[Int, Node]
       responsePartitionData.foreach { case (tp, data) =>
         val abortedTransactions = data.abortedTransactions.orElse(null)
         val lastStableOffset: Long = 
data.lastStableOffset.orElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
@@ -864,6 +898,21 @@ class KafkaApis(val requestChannel: RequestChannel,
           .setAbortedTransactions(abortedTransactions)
           .setRecords(data.records)
           
.setPreferredReadReplica(data.preferredReadReplica.orElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID))
+
+        if (versionId >= 16) {
+          data.error match {
+            case Errors.NOT_LEADER_OR_FOLLOWER | Errors.FENCED_LEADER_EPOCH =>
+              val leaderNode = getCurrentLeader(tp.topicPartition(), 
request.context.listenerName)
+              leaderNode.node.foreach { node =>
+                nodeEndpoints.put(node.id(), node)
+              }
+              partitionData.currentLeader()
+                .setLeaderId(leaderNode.leaderId)
+                .setLeaderEpoch(leaderNode.leaderEpoch)
+            case _ =>
+          }
+        }
+
         data.divergingEpoch.ifPresent(partitionData.setDivergingEpoch(_))
         partitions.put(tp, partitionData)
       }
@@ -887,7 +936,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
         // Prepare fetch response from converted data
         val response =
-          FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, 
unconvertedFetchResponse.sessionId, convertedData)
+          FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, 
unconvertedFetchResponse.sessionId, convertedData, 
nodeEndpoints.values.toList.asJava)
         // record the bytes out metrics only when the response is being sent
         response.data.responses.forEach { topicResponse =>
           topicResponse.partitions.forEach { data =>
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 17abdb0472d..41e67e61f5e 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -24,9 +24,10 @@ import java.util.Arrays.asList
 import java.util.concurrent.{CompletableFuture, TimeUnit}
 import java.util.{Collections, Optional, OptionalInt, OptionalLong, Properties}
 import kafka.api.LeaderAndIsr
-import kafka.cluster.Broker
+import kafka.cluster.{Broker, Partition}
 import kafka.controller.{ControllerContext, KafkaController}
 import kafka.coordinator.transaction.{InitProducerIdResult, 
TransactionCoordinator}
+import kafka.log.UnifiedLog
 import kafka.metrics.ClientMetricsTestUtils
 import kafka.network.{RequestChannel, RequestMetrics}
 import kafka.server.QuotaFactory.QuotaManagers
@@ -98,7 +99,7 @@ import org.apache.kafka.coordinator.group.GroupCoordinator
 import org.apache.kafka.server.common.{Features, MetadataVersion}
 import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, 
IBP_2_2_IV1}
 import org.apache.kafka.server.util.MockTime
-import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchParams, 
FetchPartitionData}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchParams, 
FetchPartitionData, LogConfig}
 
 class KafkaApisTest {
   private val requestChannel: RequestChannel = mock(classOf[RequestChannel])
@@ -2475,6 +2476,204 @@ class KafkaApisTest {
     }
   }
 
+  @Test
+  def testProduceResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
+    val topic = "topic"
+    addTopicToMetadataCache(topic, numPartitions = 2, numBrokers = 3)
+
+    for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
+
+      reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, 
requestChannel, txnCoordinator)
+
+      val responseCallback: ArgumentCaptor[Map[TopicPartition, 
PartitionResponse] => Unit] = 
ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
+
+      val tp = new TopicPartition(topic, 0)
+      val partition = mock(classOf[Partition])
+      val newLeaderId = 2
+      val newLeaderEpoch = 5
+
+      val produceRequest = ProduceRequest.forCurrentMagic(new 
ProduceRequestData()
+        .setTopicData(new ProduceRequestData.TopicProduceDataCollection(
+          Collections.singletonList(new ProduceRequestData.TopicProduceData()
+            .setName(tp.topic).setPartitionData(Collections.singletonList(
+            new ProduceRequestData.PartitionProduceData()
+              .setIndex(tp.partition)
+              .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new 
SimpleRecord("test".getBytes))))))
+            .iterator))
+        .setAcks(1.toShort)
+        .setTimeoutMs(5000))
+        .build(version.toShort)
+      val request = buildRequest(produceRequest)
+
+      when(replicaManager.appendRecords(anyLong,
+        anyShort,
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(AppendOrigin.CLIENT),
+        any(),
+        responseCallback.capture(),
+        any(),
+        any(),
+        any(),
+        any(),
+        any())
+      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new 
PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
+
+      when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => 
Right(partition))
+      when(partition.leaderReplicaIdOpt).thenAnswer(_ => Some(newLeaderId))
+      when(partition.getLeaderEpoch).thenAnswer(_ => newLeaderEpoch)
+
+      
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
+        any[Long])).thenReturn(0)
+      when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+        any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+
+      createKafkaApis().handleProduceRequest(request, 
RequestLocal.withThreadConfinedCaching)
+
+      val response = verifyNoThrottling[ProduceResponse](request)
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = 
topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, 
Errors.forCode(partitionProduceResponse.errorCode))
+      assertEquals(newLeaderId, 
partitionProduceResponse.currentLeader.leaderId())
+      assertEquals(newLeaderEpoch, 
partitionProduceResponse.currentLeader.leaderEpoch())
+      assertEquals(1, response.data.nodeEndpoints.size)
+      val node = response.data.nodeEndpoints.asScala.head
+      assertEquals(2, node.nodeId)
+      assertEquals("broker2", node.host)
+    }
+  }
+
+  @Test
+  def testProduceResponseReplicaManagerLookupErrorOnNotLeaderOrFollower(): 
Unit = {
+    val topic = "topic"
+    addTopicToMetadataCache(topic, numPartitions = 2, numBrokers = 3)
+
+    for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
+
+      reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, 
requestChannel, txnCoordinator)
+
+      val responseCallback: ArgumentCaptor[Map[TopicPartition, 
PartitionResponse] => Unit] = 
ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
+
+      val tp = new TopicPartition(topic, 0)
+
+      val produceRequest = ProduceRequest.forCurrentMagic(new 
ProduceRequestData()
+        .setTopicData(new ProduceRequestData.TopicProduceDataCollection(
+          Collections.singletonList(new ProduceRequestData.TopicProduceData()
+            .setName(tp.topic).setPartitionData(Collections.singletonList(
+            new ProduceRequestData.PartitionProduceData()
+              .setIndex(tp.partition)
+              .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new 
SimpleRecord("test".getBytes))))))
+            .iterator))
+        .setAcks(1.toShort)
+        .setTimeoutMs(5000))
+        .build(version.toShort)
+      val request = buildRequest(produceRequest)
+
+      when(replicaManager.appendRecords(anyLong,
+        anyShort,
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(AppendOrigin.CLIENT),
+        any(),
+        responseCallback.capture(),
+        any(),
+        any(),
+        any(),
+        any(),
+        any())
+      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new 
PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
+
+      when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => 
Left(Errors.UNKNOWN_TOPIC_OR_PARTITION))
+
+      
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
+        any[Long])).thenReturn(0)
+      when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+        any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+
+      createKafkaApis().handleProduceRequest(request, 
RequestLocal.withThreadConfinedCaching)
+
+      val response = verifyNoThrottling[ProduceResponse](request)
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = 
topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, 
Errors.forCode(partitionProduceResponse.errorCode))
+      // LeaderId and epoch should be the same values inserted into the 
metadata cache
+      assertEquals(0, partitionProduceResponse.currentLeader.leaderId())
+      assertEquals(1, partitionProduceResponse.currentLeader.leaderEpoch())
+      assertEquals(1, response.data.nodeEndpoints.size)
+      val node = response.data.nodeEndpoints.asScala.head
+      assertEquals(0, node.nodeId)
+      assertEquals("broker0", node.host)
+    }
+  }
+
+  @Test
+  def testProduceResponseMetadataLookupErrorOnNotLeaderOrFollower(): Unit = {
+    val topic = "topic"
+    metadataCache = mock(classOf[ZkMetadataCache])
+
+    for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
+
+      reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, 
requestChannel, txnCoordinator)
+
+      val responseCallback: ArgumentCaptor[Map[TopicPartition, 
PartitionResponse] => Unit] = 
ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
+
+      val tp = new TopicPartition(topic, 0)
+
+      val produceRequest = ProduceRequest.forCurrentMagic(new 
ProduceRequestData()
+        .setTopicData(new ProduceRequestData.TopicProduceDataCollection(
+          Collections.singletonList(new ProduceRequestData.TopicProduceData()
+            .setName(tp.topic).setPartitionData(Collections.singletonList(
+            new ProduceRequestData.PartitionProduceData()
+              .setIndex(tp.partition)
+              .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new 
SimpleRecord("test".getBytes))))))
+            .iterator))
+        .setAcks(1.toShort)
+        .setTimeoutMs(5000))
+        .build(version.toShort)
+      val request = buildRequest(produceRequest)
+
+      when(replicaManager.appendRecords(anyLong,
+        anyShort,
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(AppendOrigin.CLIENT),
+        any(),
+        responseCallback.capture(),
+        any(),
+        any(),
+        any(),
+        any(),
+        any())
+      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new 
PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
+
+      when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => 
Left(Errors.UNKNOWN_TOPIC_OR_PARTITION))
+
+      
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
+        any[Long])).thenReturn(0)
+      when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+        any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+      when(metadataCache.contains(tp)).thenAnswer(_ => true)
+      when(metadataCache.getPartitionInfo(tp.topic(), 
tp.partition())).thenAnswer(_ => Option.empty)
+      when(metadataCache.getAliveBrokerNode(any(), 
any())).thenReturn(Option.empty)
+
+      createKafkaApis().handleProduceRequest(request, 
RequestLocal.withThreadConfinedCaching)
+
+      val response = verifyNoThrottling[ProduceResponse](request)
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = 
topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, 
Errors.forCode(partitionProduceResponse.errorCode))
+      assertEquals(-1, partitionProduceResponse.currentLeader.leaderId())
+      assertEquals(-1, partitionProduceResponse.currentLeader.leaderEpoch())
+      assertEquals(0, response.data.nodeEndpoints.size)
+    }
+  }
+
   @Test
   def testTransactionalParametersSetCorrectly(): Unit = {
     val topic = "topic"
@@ -3786,6 +3985,73 @@ class KafkaApisTest {
     assertEquals(MemoryRecords.EMPTY, 
FetchResponse.recordsOrFail(partitionData))
   }
 
+  @Test
+  def testFetchResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
+    val topicId = Uuid.randomUuid()
+    val tidp = new TopicIdPartition(topicId, new TopicPartition("foo", 0))
+    val tp = tidp.topicPartition
+    addTopicToMetadataCache(tp.topic, numPartitions = 1, numBrokers = 3, 
topicId)
+
+    
when(replicaManager.getLogConfig(ArgumentMatchers.eq(tp))).thenReturn(Some(LogConfig.fromProps(
+      Collections.emptyMap(),
+      new Properties()
+    )))
+
+    val partition = mock(classOf[Partition])
+    val newLeaderId = 2
+    val newLeaderEpoch = 5
+
+    when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => 
Right(partition))
+    when(partition.leaderReplicaIdOpt).thenAnswer(_ => Some(newLeaderId))
+    when(partition.getLeaderEpoch).thenAnswer(_ => newLeaderEpoch)
+
+    when(replicaManager.fetchMessages(
+      any[FetchParams],
+      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
+      any[ReplicaQuota],
+      any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]()
+    )).thenAnswer(invocation => {
+      val callback = 
invocation.getArgument(3).asInstanceOf[Seq[(TopicIdPartition, 
FetchPartitionData)] => Unit]
+      callback(Seq(tidp -> new 
FetchPartitionData(Errors.NOT_LEADER_OR_FOLLOWER, UnifiedLog.UnknownOffset, 
UnifiedLog.UnknownOffset, MemoryRecords.EMPTY,
+        Optional.empty(), OptionalLong.empty(), Optional.empty(), 
OptionalInt.empty(), false)))
+    })
+
+    val fetchData = Map(tidp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 
0, 0, 1000,
+      Optional.empty())).asJava
+    val fetchDataBuilder = Map(tp -> new 
FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
+      Optional.empty())).asJava
+    val fetchMetadata = new JFetchMetadata(0, 0)
+    val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 
100),
+      fetchMetadata, fetchData, false, false)
+    when(fetchManager.newContext(
+      any[Short],
+      any[JFetchMetadata],
+      any[Boolean],
+      any[util.Map[TopicIdPartition, FetchRequest.PartitionData]],
+      any[util.List[TopicIdPartition]],
+      any[util.Map[Uuid, String]])).thenReturn(fetchContext)
+
+    when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+      any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+
+    val fetchRequest = new FetchRequest.Builder(16, 16, -1, -1, 100, 0, 
fetchDataBuilder)
+      .build()
+    val request = buildRequest(fetchRequest)
+
+    createKafkaApis().handleFetchRequest(request)
+
+    val response = verifyNoThrottling[FetchResponse](request)
+    val responseData = response.responseData(metadataCache.topicIdsToNames(), 
16)
+
+    val partitionData = responseData.get(tp)
+    assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, partitionData.errorCode)
+    assertEquals(newLeaderId, partitionData.currentLeader.leaderId())
+    assertEquals(newLeaderEpoch, partitionData.currentLeader.leaderEpoch())
+    val node = response.data.nodeEndpoints.asScala.head
+    assertEquals(2, node.nodeId)
+    assertEquals("broker2", node.host)
+  }
+
   @ParameterizedTest
   @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
   def testHandleJoinGroupRequest(version: Short): Unit = {

Reply via email to