This is an automated email from the ASF dual-hosted git repository.

ijuma pushed a commit to branch 3.7
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.7 by this push:
     new f31307aa223 KAFKA-16226; Reduce synchronization between producer 
threads (#15323) (#15493)
f31307aa223 is described below

commit f31307aa223c1f2f418a41bbf8f4e2100659e319
Author: Mayank Shekhar Narula <42991652+msn-t...@users.noreply.github.com>
AuthorDate: Thu Mar 14 13:46:11 2024 +0000

    KAFKA-16226; Reduce synchronization between producer threads (#15323) 
(#15493)
    
    As this [JIRA](https://issues.apache.org/jira/browse/KAFKA-16226) explains, 
there is increased synchronization between the application thread and the 
background thread, because the background thread started to call the synchronized 
method `Metadata.currentLeader()` in the [original 
PR](https://github.com/apache/kafka/pull/14384). So this PR makes the following 
changes:
    1. Changes the background thread, i.e. RecordAccumulator's partitionReady() 
and drainBatchesForOneNode(), to not use `Metadata.currentLeader()`. Instead, 
they rely on `MetadataCache`, which is immutable, so access to it is unsynchronized.
    2. Repurposes `MetadataCache` as an immutable snapshot of `Metadata`. It is 
a wrapper around the public `Cluster`; for internal client usage, `MetadataCache`'s 
API/functionality should be extended instead of the public `Cluster`'s. For 
example, this PR adds `MetadataCache.leaderEpochFor()`.
    3. Renames `MetadataCache` to `MetadataSnapshot` to make it explicit that it 
is immutable.
    
    **Note that neither `Cluster` nor `MetadataCache` is synchronized, hence this 
reduces synchronization on the hot path for high partition counts.**
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
---
 .../java/org/apache/kafka/clients/Metadata.java    |  65 +--
 .../{MetadataCache.java => MetadataSnapshot.java}  |  77 ++--
 .../clients/producer/internals/ProducerBatch.java  |  13 +-
 .../producer/internals/RecordAccumulator.java      |  78 ++--
 .../kafka/clients/producer/internals/Sender.java   |   6 +-
 .../kafka/common/requests/MetadataResponse.java    |   6 +-
 ...ataCacheTest.java => MetadataSnapshotTest.java} |  56 ++-
 .../org/apache/kafka/clients/MetadataTest.java     |   4 +-
 .../producer/internals/ProducerBatchTest.java      |  36 +-
 .../producer/internals/RecordAccumulatorTest.java  | 468 +++++++++------------
 .../clients/producer/internals/SenderTest.java     |   9 +-
 .../producer/internals/TransactionManagerTest.java |  33 +-
 .../test/java/org/apache/kafka/test/TestUtils.java |  39 ++
 13 files changed, 476 insertions(+), 414 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/Metadata.java 
b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
index 607c74eeddb..30cad44a4bc 100644
--- a/clients/src/main/java/org/apache/kafka/clients/Metadata.java
+++ b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
@@ -75,7 +75,7 @@ public class Metadata implements Closeable {
     private KafkaException fatalException;
     private Set<String> invalidTopics;
     private Set<String> unauthorizedTopics;
-    private MetadataCache cache = MetadataCache.empty();
+    private volatile MetadataSnapshot metadataSnapshot = 
MetadataSnapshot.empty();
     private boolean needFullUpdate;
     private boolean needPartialUpdate;
     private long equivalentResponseCount;
@@ -123,8 +123,15 @@ public class Metadata implements Closeable {
     /**
      * Get the current cluster info without blocking
      */
-    public synchronized Cluster fetch() {
-        return cache.cluster();
+    public Cluster fetch() {
+        return metadataSnapshot.cluster();
+    }
+
+    /**
+     * Get the current metadata cache.
+     */
+    public MetadataSnapshot fetchMetadataSnapshot() {
+        return metadataSnapshot;
     }
 
     /**
@@ -265,7 +272,7 @@ public class Metadata implements Closeable {
      */
     synchronized Optional<MetadataResponse.PartitionMetadata> 
partitionMetadataIfCurrent(TopicPartition topicPartition) {
         Integer epoch = lastSeenLeaderEpochs.get(topicPartition);
-        Optional<MetadataResponse.PartitionMetadata> partitionMetadata = 
cache.partitionMetadata(topicPartition);
+        Optional<MetadataResponse.PartitionMetadata> partitionMetadata = 
metadataSnapshot.partitionMetadata(topicPartition);
         if (epoch == null) {
             // old cluster format (no epochs)
             return partitionMetadata;
@@ -278,8 +285,8 @@ public class Metadata implements Closeable {
     /**
      * @return a mapping from topic names to topic IDs for all topics with 
valid IDs in the cache
      */
-    public synchronized Map<String, Uuid> topicIds() {
-        return cache.topicIds();
+    public Map<String, Uuid> topicIds() {
+        return metadataSnapshot.topicIds();
     }
 
     public synchronized LeaderAndEpoch currentLeader(TopicPartition 
topicPartition) {
@@ -289,14 +296,14 @@ public class Metadata implements Closeable {
 
         MetadataResponse.PartitionMetadata partitionMetadata = 
maybeMetadata.get();
         Optional<Integer> leaderEpochOpt = partitionMetadata.leaderEpoch;
-        Optional<Node> leaderNodeOpt = 
partitionMetadata.leaderId.flatMap(cache::nodeById);
+        Optional<Node> leaderNodeOpt = 
partitionMetadata.leaderId.flatMap(metadataSnapshot::nodeById);
         return new LeaderAndEpoch(leaderNodeOpt, leaderEpochOpt);
     }
 
     public synchronized void bootstrap(List<InetSocketAddress> addresses) {
         this.needFullUpdate = true;
         this.updateVersion += 1;
-        this.cache = MetadataCache.bootstrap(addresses);
+        this.metadataSnapshot = MetadataSnapshot.bootstrap(addresses);
     }
 
     /**
@@ -335,22 +342,22 @@ public class Metadata implements Closeable {
         // this count is reset to 0 in updateLatestMetadata()
         this.equivalentResponseCount++;
 
-        String previousClusterId = cache.clusterResource().clusterId();
+        String previousClusterId = 
metadataSnapshot.clusterResource().clusterId();
 
-        this.cache = handleMetadataResponse(response, isPartialUpdate, nowMs);
+        this.metadataSnapshot = handleMetadataResponse(response, 
isPartialUpdate, nowMs);
 
-        Cluster cluster = cache.cluster();
+        Cluster cluster = metadataSnapshot.cluster();
         maybeSetMetadataError(cluster);
 
         this.lastSeenLeaderEpochs.keySet().removeIf(tp -> 
!retainTopic(tp.topic(), false, nowMs));
 
-        String newClusterId = cache.clusterResource().clusterId();
+        String newClusterId = metadataSnapshot.clusterResource().clusterId();
         if (!Objects.equals(previousClusterId, newClusterId)) {
             log.info("Cluster ID: {}", newClusterId);
         }
-        clusterResourceListeners.onUpdate(cache.clusterResource());
+        clusterResourceListeners.onUpdate(metadataSnapshot.clusterResource());
 
-        log.debug("Updated cluster metadata updateVersion {} to {}", 
this.updateVersion, this.cache);
+        log.debug("Updated cluster metadata updateVersion {} to {}", 
this.updateVersion, this.metadataSnapshot);
     }
 
     /**
@@ -365,7 +372,7 @@ public class Metadata implements Closeable {
     public synchronized Set<TopicPartition> 
updatePartitionLeadership(Map<TopicPartition, LeaderIdAndEpoch> 
partitionLeaders, List<Node> leaderNodes) {
         Map<Integer, Node> newNodes = 
leaderNodes.stream().collect(Collectors.toMap(Node::id, node -> node));
         // Insert non-overlapping nodes from existing-nodes into new-nodes.
-        this.cache.cluster().nodes().stream().forEach(node -> 
newNodes.putIfAbsent(node.id(), node));
+        this.metadataSnapshot.cluster().nodes().stream().forEach(node -> 
newNodes.putIfAbsent(node.id(), node));
 
         // Create partition-metadata for all updated partitions. Exclude 
updates for partitions -
         // 1. for which the corresponding partition has newer leader in 
existing metadata.
@@ -388,12 +395,12 @@ public class Metadata implements Closeable {
                 log.debug("For {}, incoming leader({}), the corresponding node 
information for node-id {} is missing, so ignoring.", partition, newLeader, 
newLeader.leaderId.get());
                 continue;
             }
-            if (!this.cache.partitionMetadata(partition).isPresent()) {
+            if 
(!this.metadataSnapshot.partitionMetadata(partition).isPresent()) {
                 log.debug("For {}, incoming leader({}), partition metadata is 
no longer cached, ignoring.", partition, newLeader);
                 continue;
             }
 
-            MetadataResponse.PartitionMetadata existingMetadata = 
this.cache.partitionMetadata(partition).get();
+            MetadataResponse.PartitionMetadata existingMetadata = 
this.metadataSnapshot.partitionMetadata(partition).get();
             MetadataResponse.PartitionMetadata updatedMetadata = new 
MetadataResponse.PartitionMetadata(
                 existingMetadata.error,
                 partition,
@@ -416,7 +423,7 @@ public class Metadata implements Closeable {
         Set<String> updatedTopics = 
updatePartitionMetadata.stream().map(MetadataResponse.PartitionMetadata::topic).collect(Collectors.toSet());
 
         // Get topic-ids for updated topics from existing topic-ids.
-        Map<String, Uuid> existingTopicIds = this.cache.topicIds();
+        Map<String, Uuid> existingTopicIds = this.metadataSnapshot.topicIds();
         Map<String, Uuid> topicIdsForUpdatedTopics = updatedTopics.stream()
             .filter(e -> existingTopicIds.containsKey(e))
             .collect(Collectors.toMap(e -> e, e -> existingTopicIds.get(e)));
@@ -429,15 +436,15 @@ public class Metadata implements Closeable {
 
         // Fetch responses can include partition level leader changes, when 
this happens, we perform a partial
         // metadata update, by keeping the unchanged partition and update the 
changed partitions.
-        this.cache = cache.mergeWith(
-            cache.clusterResource().clusterId(),
+        this.metadataSnapshot = metadataSnapshot.mergeWith(
+            metadataSnapshot.clusterResource().clusterId(),
             newNodes,
             updatePartitionMetadata,
             Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(),
-            cache.cluster().controller(),
+            metadataSnapshot.cluster().controller(),
             topicIdsForUpdatedTopics,
             (topic, isInternal) -> true);
-        clusterResourceListeners.onUpdate(cache.clusterResource());
+        clusterResourceListeners.onUpdate(metadataSnapshot.clusterResource());
 
         return updatePartitionMetadata.stream()
             .map(metadata -> metadata.topicPartition)
@@ -467,7 +474,7 @@ public class Metadata implements Closeable {
     /**
      * Transform a MetadataResponse into a new MetadataCache instance.
      */
-    private MetadataCache handleMetadataResponse(MetadataResponse 
metadataResponse, boolean isPartialUpdate, long nowMs) {
+    private MetadataSnapshot handleMetadataResponse(MetadataResponse 
metadataResponse, boolean isPartialUpdate, long nowMs) {
         // All encountered topics.
         Set<String> topics = new HashSet<>();
 
@@ -478,7 +485,7 @@ public class Metadata implements Closeable {
 
         List<MetadataResponse.PartitionMetadata> partitions = new 
ArrayList<>();
         Map<String, Uuid> topicIds = new HashMap<>();
-        Map<String, Uuid> oldTopicIds = cache.topicIds();
+        Map<String, Uuid> oldTopicIds = metadataSnapshot.topicIds();
         for (MetadataResponse.TopicMetadata metadata : 
metadataResponse.topicMetadata()) {
             String topicName = metadata.topic();
             Uuid topicId = metadata.topicId();
@@ -526,11 +533,11 @@ public class Metadata implements Closeable {
 
         Map<Integer, Node> nodes = metadataResponse.brokersById();
         if (isPartialUpdate)
-            return this.cache.mergeWith(metadataResponse.clusterId(), nodes, 
partitions,
+            return 
this.metadataSnapshot.mergeWith(metadataResponse.clusterId(), nodes, partitions,
                 unauthorizedTopics, invalidTopics, internalTopics, 
metadataResponse.controller(), topicIds,
                 (topic, isInternal) -> !topics.contains(topic) && 
retainTopic(topic, isInternal, nowMs));
         else
-            return new MetadataCache(metadataResponse.clusterId(), nodes, 
partitions,
+            return new MetadataSnapshot(metadataResponse.clusterId(), nodes, 
partitions,
                 unauthorizedTopics, invalidTopics, internalTopics, 
metadataResponse.controller(), topicIds);
     }
 
@@ -575,7 +582,7 @@ public class Metadata implements Closeable {
             } else {
                 // Otherwise ignore the new metadata and use the previously 
cached info
                 log.debug("Got metadata for an older epoch {} (current is {}) 
for partition {}, not updating", newEpoch, currentEpoch, tp);
-                return cache.partitionMetadata(tp);
+                return metadataSnapshot.partitionMetadata(tp);
             }
         } else {
             // Handle old cluster formats as well as error responses where 
leader and epoch are missing
@@ -738,8 +745,8 @@ public class Metadata implements Closeable {
     /**
      * @return Mapping from topic IDs to topic names for all topics in the 
cache.
      */
-    public synchronized Map<Uuid, String> topicNames() {
-        return cache.topicNames();
+    public Map<Uuid, String> topicNames() {
+        return metadataSnapshot.topicNames();
     }
 
     protected boolean retainTopic(String topic, boolean isInternal, long 
nowMs) {
diff --git a/clients/src/main/java/org/apache/kafka/clients/MetadataCache.java 
b/clients/src/main/java/org/apache/kafka/clients/MetadataSnapshot.java
similarity index 77%
rename from clients/src/main/java/org/apache/kafka/clients/MetadataCache.java
rename to clients/src/main/java/org/apache/kafka/clients/MetadataSnapshot.java
index 45574c3549c..fbfb828b31c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/MetadataCache.java
+++ b/clients/src/main/java/org/apache/kafka/clients/MetadataSnapshot.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.clients;
 
+import java.util.OptionalInt;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.ClusterResource;
 import org.apache.kafka.common.Node;
@@ -39,10 +40,11 @@ import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
 /**
- * An internal mutable cache of nodes, topics, and partitions in the Kafka 
cluster. This keeps an up-to-date Cluster
+ * An internal immutable snapshot of nodes, topics, and partitions in the 
Kafka cluster. This keeps an up-to-date Cluster
  * instance which is optimized for read access.
+ * Prefer to extend MetadataSnapshot's API for internal client usage Vs the 
public {@link Cluster}
  */
-public class MetadataCache {
+public class MetadataSnapshot {
     private final String clusterId;
     private final Map<Integer, Node> nodes;
     private final Set<String> unauthorizedTopics;
@@ -54,7 +56,7 @@ public class MetadataCache {
     private final Map<Uuid, String> topicNames;
     private Cluster clusterInstance;
 
-    MetadataCache(String clusterId,
+    public MetadataSnapshot(String clusterId,
                   Map<Integer, Node> nodes,
                   Collection<PartitionMetadata> partitions,
                   Set<String> unauthorizedTopics,
@@ -65,30 +67,32 @@ public class MetadataCache {
         this(clusterId, nodes, partitions, unauthorizedTopics, invalidTopics, 
internalTopics, controller, topicIds, null);
     }
 
-    private MetadataCache(String clusterId,
-                          Map<Integer, Node> nodes,
-                          Collection<PartitionMetadata> partitions,
-                          Set<String> unauthorizedTopics,
-                          Set<String> invalidTopics,
-                          Set<String> internalTopics,
-                          Node controller,
-                          Map<String, Uuid> topicIds,
-                          Cluster clusterInstance) {
+    // Visible for testing
+    public MetadataSnapshot(String clusterId,
+        Map<Integer, Node> nodes,
+        Collection<PartitionMetadata> partitions,
+        Set<String> unauthorizedTopics,
+        Set<String> invalidTopics,
+        Set<String> internalTopics,
+        Node controller,
+        Map<String, Uuid> topicIds,
+        Cluster clusterInstance) {
         this.clusterId = clusterId;
-        this.nodes = nodes;
-        this.unauthorizedTopics = unauthorizedTopics;
-        this.invalidTopics = invalidTopics;
-        this.internalTopics = internalTopics;
+        this.nodes = Collections.unmodifiableMap(nodes);
+        this.unauthorizedTopics = 
Collections.unmodifiableSet(unauthorizedTopics);
+        this.invalidTopics = Collections.unmodifiableSet(invalidTopics);
+        this.internalTopics = Collections.unmodifiableSet(internalTopics);
         this.controller = controller;
         this.topicIds = Collections.unmodifiableMap(topicIds);
         this.topicNames = Collections.unmodifiableMap(
             
topicIds.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, 
Map.Entry::getKey))
         );
 
-        this.metadataByPartition = new HashMap<>(partitions.size());
+        Map<TopicPartition, PartitionMetadata> tmpMetadataByPartition = new 
HashMap<>(partitions.size());
         for (PartitionMetadata p : partitions) {
-            this.metadataByPartition.put(p.topicPartition, p);
+            tmpMetadataByPartition.put(p.topicPartition, p);
         }
+        this.metadataByPartition = 
Collections.unmodifiableMap(tmpMetadataByPartition);
 
         if (clusterInstance == null) {
             computeClusterView();
@@ -113,7 +117,7 @@ public class MetadataCache {
         return Optional.ofNullable(nodes.get(id));
     }
 
-    Cluster cluster() {
+    public Cluster cluster() {
         if (clusterInstance == null) {
             throw new IllegalStateException("Cached Cluster instance should 
not be null, but was.");
         } else {
@@ -121,13 +125,28 @@ public class MetadataCache {
         }
     }
 
+    /**
+     * Get leader-epoch for partition.
+     *
+     * @param tp partition
+     * @return leader-epoch if known, else return OptionalInt.empty()
+     */
+    public OptionalInt leaderEpochFor(TopicPartition tp) {
+        PartitionMetadata partitionMetadata = metadataByPartition.get(tp);
+        if (partitionMetadata == null || 
!partitionMetadata.leaderEpoch.isPresent()) {
+            return OptionalInt.empty();
+        } else {
+            return OptionalInt.of(partitionMetadata.leaderEpoch.get());
+        }
+    }
+
     ClusterResource clusterResource() {
         return new ClusterResource(clusterId);
     }
 
     /**
-     * Merges the metadata cache's contents with the provided metadata, 
returning a new metadata cache. The provided
-     * metadata is presumed to be more recent than the cache's metadata, and 
therefore all overlapping metadata will
+     * Merges the metadata snapshot's contents with the provided metadata, 
returning a new metadata snapshot. The provided
+     * metadata is presumed to be more recent than the snapshot's metadata, 
and therefore all overlapping metadata will
      * be overridden.
      *
      * @param newClusterId the new cluster Id
@@ -138,9 +157,9 @@ public class MetadataCache {
      * @param newController the new controller node
      * @param addTopicIds the mapping from topic name to topic ID, for topics 
in addPartitions
      * @param retainTopic returns whether a pre-existing topic's metadata 
should be retained
-     * @return the merged metadata cache
+     * @return the merged metadata snapshot
      */
-    MetadataCache mergeWith(String newClusterId,
+    MetadataSnapshot mergeWith(String newClusterId,
                             Map<Integer, Node> newNodes,
                             Collection<PartitionMetadata> addPartitions,
                             Set<String> addUnauthorizedTopics,
@@ -180,7 +199,7 @@ public class MetadataCache {
         Set<String> newInvalidTopics = fillSet(addInvalidTopics, 
invalidTopics, shouldRetainTopic);
         Set<String> newInternalTopics = fillSet(addInternalTopics, 
internalTopics, shouldRetainTopic);
 
-        return new MetadataCache(newClusterId, newNodes, 
newMetadataByPartition.values(), newUnauthorizedTopics,
+        return new MetadataSnapshot(newClusterId, newNodes, 
newMetadataByPartition.values(), newUnauthorizedTopics,
                 newInvalidTopics, newInternalTopics, newController, 
newTopicIds);
     }
 
@@ -212,26 +231,26 @@ public class MetadataCache {
                 invalidTopics, internalTopics, controller, topicIds);
     }
 
-    static MetadataCache bootstrap(List<InetSocketAddress> addresses) {
+    static MetadataSnapshot bootstrap(List<InetSocketAddress> addresses) {
         Map<Integer, Node> nodes = new HashMap<>();
         int nodeId = -1;
         for (InetSocketAddress address : addresses) {
             nodes.put(nodeId, new Node(nodeId, address.getHostString(), 
address.getPort()));
             nodeId--;
         }
-        return new MetadataCache(null, nodes, Collections.emptyList(),
+        return new MetadataSnapshot(null, nodes, Collections.emptyList(),
                 Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(),
                 null, Collections.emptyMap(), Cluster.bootstrap(addresses));
     }
 
-    static MetadataCache empty() {
-        return new MetadataCache(null, Collections.emptyMap(), 
Collections.emptyList(),
+    static MetadataSnapshot empty() {
+        return new MetadataSnapshot(null, Collections.emptyMap(), 
Collections.emptyList(),
                 Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(), null, Collections.emptyMap(), Cluster.empty());
     }
 
     @Override
     public String toString() {
-        return "MetadataCache{" +
+        return "MetadataSnapshot{" +
                 "clusterId='" + clusterId + '\'' +
                 ", nodes=" + nodes +
                 ", partitions=" + metadataByPartition.values() +
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
index 61432b53ab0..391cc1b3441 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
@@ -16,7 +16,7 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
-import java.util.Optional;
+import java.util.OptionalInt;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.TopicPartition;
@@ -81,7 +81,7 @@ public final class ProducerBatch {
     private boolean reopened;
 
     // Tracks the current-leader's epoch to which this batch would be sent, in 
the current to produce the batch.
-    private Optional<Integer> currentLeaderEpoch;
+    private OptionalInt currentLeaderEpoch;
     // Tracks the attempt in which leader was changed to currentLeaderEpoch 
for the 1st time.
     private int attemptsWhenLeaderLastChanged;
 
@@ -100,7 +100,7 @@ public final class ProducerBatch {
         this.isSplitBatch = isSplitBatch;
         float compressionRatioEstimation = 
CompressionRatioEstimator.estimation(topicPartition.topic(),
                                                                                
 recordsBuilder.compressionType());
-        this.currentLeaderEpoch = Optional.empty();
+        this.currentLeaderEpoch = OptionalInt.empty();
         this.attemptsWhenLeaderLastChanged = 0;
         
recordsBuilder.setEstimatedCompressionRatio(compressionRatioEstimation);
     }
@@ -109,8 +109,9 @@ public final class ProducerBatch {
      * It will update the leader to which this batch will be produced for the 
ongoing attempt, if a newer leader is known.
      * @param latestLeaderEpoch latest leader's epoch.
      */
-    void maybeUpdateLeaderEpoch(Optional<Integer> latestLeaderEpoch) {
-        if (!currentLeaderEpoch.equals(latestLeaderEpoch)) {
+    void maybeUpdateLeaderEpoch(OptionalInt latestLeaderEpoch) {
+        if (latestLeaderEpoch.isPresent()
+            && (!currentLeaderEpoch.isPresent() || 
currentLeaderEpoch.getAsInt() < latestLeaderEpoch.getAsInt())) {
             log.trace("For {}, leader will be updated, currentLeaderEpoch: {}, 
attemptsWhenLeaderLastChanged:{}, latestLeaderEpoch: {}, current attempt: {}",
                 this, currentLeaderEpoch, attemptsWhenLeaderLastChanged, 
latestLeaderEpoch, attempts);
             attemptsWhenLeaderLastChanged = attempts();
@@ -558,7 +559,7 @@ public final class ProducerBatch {
     }
 
     // VisibleForTesting
-    Optional<Integer> currentLeaderEpoch() {
+    OptionalInt currentLeaderEpoch() {
         return currentLeaderEpoch;
     }
 
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index 887e789edfc..faef3a759a1 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -26,13 +26,14 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.CommonClientConfigs;
-import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.clients.MetadataSnapshot;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.utils.ExponentialBackoff;
@@ -647,27 +648,27 @@ public class RecordAccumulator {
     }
 
     /**
-     * Iterate over partitions to see which one have batches ready and collect 
leaders of those partitions
-     * into the set of ready nodes.  If partition has no leader, add the topic 
to the set of topics with
-     * no leader.  This function also calculates stats for adaptive 
partitioning.
+     * Iterate over partitions to see which one have batches ready and collect 
leaders of those
+     * partitions into the set of ready nodes.  If partition has no leader, 
add the topic to the set
+     * of topics with no leader.  This function also calculates stats for 
adaptive partitioning.
      *
-     * @param metadata The cluster metadata
-     * @param nowMs The current time
-     * @param topic The topic
-     * @param topicInfo The topic info
+     * @param metadataSnapshot      The cluster metadata
+     * @param nowMs                 The current time
+     * @param topic                 The topic
+     * @param topicInfo             The topic info
      * @param nextReadyCheckDelayMs The delay for next check
-     * @param readyNodes The set of ready nodes (to be filled in)
-     * @param unknownLeaderTopics The set of topics with no leader (to be 
filled in)
+     * @param readyNodes            The set of ready nodes (to be filled in)
+     * @param unknownLeaderTopics   The set of topics with no leader (to be 
filled in)
      * @return The delay for next check
      */
-    private long partitionReady(Metadata metadata, long nowMs, String topic,
+    private long partitionReady(MetadataSnapshot metadataSnapshot, long nowMs, 
String topic,
                                 TopicInfo topicInfo,
                                 long nextReadyCheckDelayMs, Set<Node> 
readyNodes, Set<String> unknownLeaderTopics) {
         ConcurrentMap<Integer, Deque<ProducerBatch>> batches = 
topicInfo.batches;
         // Collect the queue sizes for available partitions to be used in 
adaptive partitioning.
         int[] queueSizes = null;
         int[] partitionIds = null;
-        if (enableAdaptivePartitioning && batches.size() >= 
metadata.fetch().partitionsForTopic(topic).size()) {
+        if (enableAdaptivePartitioning && batches.size() >= 
metadataSnapshot.cluster().partitionsForTopic(topic).size()) {
             // We don't do adaptive partitioning until we scheduled at least a 
batch for all
             // partitions (i.e. we have the corresponding entries in the 
batches map), we just
             // do uniform.  The reason is that we build queue sizes from the 
batches map,
@@ -684,8 +685,7 @@ public class RecordAccumulator {
             // Advance queueSizesIndex so that we properly index available
             // partitions.  Do it here so that it's done for all code paths.
 
-            Metadata.LeaderAndEpoch leaderAndEpoch = 
metadata.currentLeader(part);
-            Node leader = leaderAndEpoch.leader.orElse(null);
+            Node leader = metadataSnapshot.cluster().leaderFor(part);
             if (leader != null && queueSizes != null) {
                 ++queueSizesIndex;
                 assert queueSizesIndex < queueSizes.length;
@@ -700,9 +700,14 @@ public class RecordAccumulator {
             final int dequeSize;
             final boolean full;
 
-            // This loop is especially hot with large partition counts.
+            OptionalInt leaderEpoch = metadataSnapshot.leaderEpochFor(part);
+
+            // This loop is especially hot with large partition counts. So -
 
-            // We are careful to only perform the minimum required inside the
+            // 1. We should avoid code that increases synchronization between 
application thread calling
+            // send(), and background thread running runOnce(), see 
https://issues.apache.org/jira/browse/KAFKA-16226
+
+            // 2. We are careful to only perform the minimum required inside 
the
             // synchronized block, as this lock is also used to synchronize 
producer threads
             // attempting to append() to a partition/batch.
 
@@ -715,7 +720,7 @@ public class RecordAccumulator {
                 }
 
                 waitedTimeMs = batch.waitedTimeMs(nowMs);
-                batch.maybeUpdateLeaderEpoch(leaderAndEpoch.epoch);
+                batch.maybeUpdateLeaderEpoch(leaderEpoch);
                 backingOff = 
shouldBackoff(batch.hasLeaderChangedForTheOngoingRetry(), batch, waitedTimeMs);
                 backoffAttempts = batch.attempts();
                 dequeSize = deque.size();
@@ -776,7 +781,7 @@ public class RecordAccumulator {
      * </ul>
      * </ol>
      */
-    public ReadyCheckResult ready(Metadata metadata, long nowMs) {
+    public ReadyCheckResult ready(MetadataSnapshot metadataSnapshot, long 
nowMs) {
         Set<Node> readyNodes = new HashSet<>();
         long nextReadyCheckDelayMs = Long.MAX_VALUE;
         Set<String> unknownLeaderTopics = new HashSet<>();
@@ -784,7 +789,7 @@ public class RecordAccumulator {
         // cumulative frequency table (used in partitioner).
         for (Map.Entry<String, TopicInfo> topicInfoEntry : 
this.topicInfoMap.entrySet()) {
             final String topic = topicInfoEntry.getKey();
-            nextReadyCheckDelayMs = partitionReady(metadata, nowMs, topic, 
topicInfoEntry.getValue(), nextReadyCheckDelayMs, readyNodes, 
unknownLeaderTopics);
+            nextReadyCheckDelayMs = partitionReady(metadataSnapshot, nowMs, 
topic, topicInfoEntry.getValue(), nextReadyCheckDelayMs, readyNodes, 
unknownLeaderTopics);
         }
         return new ReadyCheckResult(readyNodes, nextReadyCheckDelayMs, 
unknownLeaderTopics);
     }
@@ -861,9 +866,9 @@ public class RecordAccumulator {
         return false;
     }
 
-    private List<ProducerBatch> drainBatchesForOneNode(Metadata metadata, Node 
node, int maxSize, long now) {
+    private List<ProducerBatch> drainBatchesForOneNode(MetadataSnapshot 
metadataSnapshot, Node node, int maxSize, long now) {
         int size = 0;
-        List<PartitionInfo> parts = 
metadata.fetch().partitionsForNode(node.id());
+        List<PartitionInfo> parts = 
metadataSnapshot.cluster().partitionsForNode(node.id());
         List<ProducerBatch> ready = new ArrayList<>();
         if (parts.isEmpty())
             return ready;
@@ -879,17 +884,12 @@ public class RecordAccumulator {
             // Only proceed if the partition has no in-flight batches.
             if (isMuted(tp))
                 continue;
-            Metadata.LeaderAndEpoch leaderAndEpoch = 
metadata.currentLeader(tp);
-            // Although a small chance, but skip this partition if leader has 
changed since the partition -> node assignment obtained from outside the loop.
-            // In this case, skip sending it to the old leader, as it would 
return aa NO_LEADER_OR_FOLLOWER error.
-            if (!leaderAndEpoch.leader.isPresent())
-                continue;
-            if (!node.equals(leaderAndEpoch.leader.get()))
-                continue;
             Deque<ProducerBatch> deque = getDeque(tp);
             if (deque == null)
                 continue;
 
+            OptionalInt leaderEpoch = metadataSnapshot.leaderEpochFor(tp);
+
             final ProducerBatch batch;
             synchronized (deque) {
                 // invariant: !isMuted(tp,now) && deque != null
@@ -899,7 +899,7 @@ public class RecordAccumulator {
 
                 // first != null
                 // Only drain the batch if it is not during backoff period.
-                first.maybeUpdateLeaderEpoch(leaderAndEpoch.epoch);
+                first.maybeUpdateLeaderEpoch(leaderEpoch);
                 if (shouldBackoff(first.hasLeaderChangedForTheOngoingRetry(), 
first, first.waitedTimeMs(now)))
                     continue;
 
@@ -963,22 +963,24 @@ public class RecordAccumulator {
     }
 
     /**
-     * Drain all the data for the given nodes and collate them into a list of 
batches that will fit within the specified
-     * size on a per-node basis. This method attempts to avoid choosing the 
same topic-node over and over.
+     * Drain all the data for the given nodes and collate them into a list of 
batches that will fit
+     * within the specified size on a per-node basis. This method attempts to 
avoid choosing the same
+     * topic-node over and over.
      *
-     * @param metadata The current cluster metadata
-     * @param nodes The list of node to drain
-     * @param maxSize The maximum number of bytes to drain
-     * @param now The current unix time in milliseconds
-     * @return A list of {@link ProducerBatch} for each node specified with 
total size less than the requested maxSize.
+     * @param metadataSnapshot  An immutable snapshot of the current cluster metadata
+     * @param nodes             The set of nodes to drain
+     * @param maxSize           The maximum number of bytes to drain
+     * @param now               The current unix time in milliseconds
+     * @return A map from each specified node's id to its list of {@link ProducerBatch}, with
+     * total size less than the requested maxSize.
      */
-    public Map<Integer, List<ProducerBatch>> drain(Metadata metadata, 
Set<Node> nodes, int maxSize, long now) {
+    public Map<Integer, List<ProducerBatch>> drain(MetadataSnapshot 
metadataSnapshot, Set<Node> nodes, int maxSize, long now) {
         if (nodes.isEmpty())
             return Collections.emptyMap();
 
         Map<Integer, List<ProducerBatch>> batches = new HashMap<>();
         for (Node node : nodes) {
-            List<ProducerBatch> ready = drainBatchesForOneNode(metadata, node, 
maxSize, now);
+            List<ProducerBatch> ready = 
drainBatchesForOneNode(metadataSnapshot, node, maxSize, now);
             batches.put(node.id(), ready);
         }
         return batches;
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index 57889b2591e..461937b00e1 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -25,6 +25,7 @@ import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.KafkaClient;
 import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.clients.MetadataSnapshot;
 import org.apache.kafka.clients.NetworkClientUtils;
 import org.apache.kafka.clients.RequestCompletionHandler;
 import org.apache.kafka.common.InvalidRecordException;
@@ -364,8 +365,9 @@ public class Sender implements Runnable {
     }
 
     private long sendProducerData(long now) {
+        MetadataSnapshot metadataSnapshot = metadata.fetchMetadataSnapshot();
         // get the list of partitions with data ready to send
-        RecordAccumulator.ReadyCheckResult result = 
this.accumulator.ready(metadata, now);
+        RecordAccumulator.ReadyCheckResult result = 
this.accumulator.ready(metadataSnapshot, now);
 
         // if there are any partitions whose leaders are not known yet, force 
metadata update
         if (!result.unknownLeaderTopics.isEmpty()) {
@@ -400,7 +402,7 @@ public class Sender implements Runnable {
         }
 
         // create produce requests
-        Map<Integer, List<ProducerBatch>> batches = 
this.accumulator.drain(metadata, result.readyNodes, this.maxRequestSize, now);
+        Map<Integer, List<ProducerBatch>> batches = 
this.accumulator.drain(metadataSnapshot, result.readyNodes, 
this.maxRequestSize, now);
         addToInflightBatches(batches);
         if (guaranteeMessageOrder) {
             // Mute all the partitions drained
diff --git 
a/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java 
b/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java
index 47cdd3f0d7e..bc6c387c5e0 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java
@@ -172,9 +172,9 @@ public class MetadataResponse extends AbstractResponse {
         return new PartitionInfo(metadata.topic(),
                 metadata.partition(),
                 metadata.leaderId.map(nodesById::get).orElse(null),
-                convertToNodeArray(metadata.replicaIds, nodesById),
-                convertToNodeArray(metadata.inSyncReplicaIds, nodesById),
-                convertToNodeArray(metadata.offlineReplicaIds, nodesById));
+                (metadata.replicaIds == null) ? null : 
convertToNodeArray(metadata.replicaIds, nodesById),
+                (metadata.inSyncReplicaIds == null) ? null : 
convertToNodeArray(metadata.inSyncReplicaIds, nodesById),
+                (metadata.offlineReplicaIds == null) ? null : 
convertToNodeArray(metadata.offlineReplicaIds, nodesById));
     }
 
     private static Node[] convertToNodeArray(List<Integer> replicaIds, 
Map<Integer, Node> nodesById) {
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/MetadataCacheTest.java 
b/clients/src/test/java/org/apache/kafka/clients/MetadataSnapshotTest.java
similarity index 77%
rename from 
clients/src/test/java/org/apache/kafka/clients/MetadataCacheTest.java
rename to 
clients/src/test/java/org/apache/kafka/clients/MetadataSnapshotTest.java
index f99eb04dfcb..1012709926a 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MetadataCacheTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MetadataSnapshotTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.clients;
 
+import java.util.OptionalInt;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
@@ -37,7 +38,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
-public class MetadataCacheTest {
+public class MetadataSnapshotTest {
 
     @Test
     public void testMissingLeaderEndpoint() {
@@ -62,7 +63,7 @@ public class MetadataCacheTest {
         nodesById.put(7, new Node(7, "localhost", 2078));
         nodesById.put(8, new Node(8, "localhost", 2079));
 
-        MetadataCache cache = new MetadataCache("clusterId",
+        MetadataSnapshot cache = new MetadataSnapshot("clusterId",
                 nodesById,
                 Collections.singleton(partitionMetadata),
                 Collections.emptySet(),
@@ -107,7 +108,7 @@ public class MetadataCacheTest {
         Uuid topic1Id = Uuid.randomUuid();
         topicsIds.put(topic1Partition.topic(), topic1Id);
 
-        MetadataCache cache = new MetadataCache("clusterId",
+        MetadataSnapshot cache = new MetadataSnapshot("clusterId",
             nodesById,
             Collections.singleton(partitionMetadata1),
             Collections.emptySet(),
@@ -155,7 +156,7 @@ public class MetadataCacheTest {
         topicIds.put("topic1", Uuid.randomUuid());
         topicIds.put("topic2", Uuid.randomUuid());
 
-        MetadataCache cache = new MetadataCache("clusterId",
+        MetadataSnapshot cache = new MetadataSnapshot("clusterId",
                 Collections.singletonMap(6, new Node(6, "localhost", 2077)),
                 Collections.emptyList(),
                 Collections.emptySet(),
@@ -174,7 +175,7 @@ public class MetadataCacheTest {
     public void testEmptyTopicNamesCacheBuiltFromTopicIds() {
         Map<String, Uuid> topicIds = new HashMap<>();
 
-        MetadataCache cache = new MetadataCache("clusterId",
+        MetadataSnapshot cache = new MetadataSnapshot("clusterId",
                 Collections.singletonMap(6, new Node(6, "localhost", 2077)),
                 Collections.emptyList(),
                 Collections.emptySet(),
@@ -185,4 +186,49 @@ public class MetadataCacheTest {
         assertEquals(Collections.emptyMap(), cache.topicNames());
     }
 
+    @Test
+    public void testLeaderEpochFor() {
+        // Set up partition 0 with a leader-epoch of 10.
+        TopicPartition topicPartition1 = new TopicPartition("topic", 0);
+        MetadataResponse.PartitionMetadata partitionMetadata1 = new 
MetadataResponse.PartitionMetadata(
+            Errors.NONE,
+            topicPartition1,
+            Optional.of(5),
+            Optional.of(10),
+            Arrays.asList(5, 6, 7),
+            Arrays.asList(5, 6, 7),
+            Collections.emptyList());
+
+        // Set up partition 1 with an unknown leader epoch.
+        TopicPartition topicPartition2 = new TopicPartition("topic", 1);
+        MetadataResponse.PartitionMetadata partitionMetadata2 = new 
MetadataResponse.PartitionMetadata(
+            Errors.NONE,
+            topicPartition2,
+            Optional.of(5),
+            Optional.empty(),
+            Arrays.asList(5, 6, 7),
+            Arrays.asList(5, 6, 7),
+            Collections.emptyList());
+
+        Map<Integer, Node> nodesById = new HashMap<>();
+        nodesById.put(5, new Node(5, "localhost", 2077));
+        nodesById.put(6, new Node(6, "localhost", 2078));
+        nodesById.put(7, new Node(7, "localhost", 2079));
+
+        MetadataSnapshot cache = new MetadataSnapshot("clusterId",
+            nodesById,
+            Arrays.asList(partitionMetadata1, partitionMetadata2),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            null,
+            Collections.emptyMap());
+
+        assertEquals(OptionalInt.of(10), 
cache.leaderEpochFor(topicPartition1));
+
+        assertEquals(OptionalInt.empty(), 
cache.leaderEpochFor(topicPartition2));
+
+        assertEquals(OptionalInt.empty(), cache.leaderEpochFor(new 
TopicPartition("topic_missing", 0)));
+    }
+
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java 
b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
index d7c91fdffad..600fc23ecb9 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
@@ -604,11 +604,11 @@ public class MetadataTest {
 
         // Sentinel instances
         InetSocketAddress address = 
InetSocketAddress.createUnresolved("localhost", 0);
-        Cluster fromMetadata = 
MetadataCache.bootstrap(Collections.singletonList(address)).cluster();
+        Cluster fromMetadata = 
MetadataSnapshot.bootstrap(Collections.singletonList(address)).cluster();
         Cluster fromCluster = 
Cluster.bootstrap(Collections.singletonList(address));
         assertEquals(fromMetadata, fromCluster);
 
-        Cluster fromMetadataEmpty = MetadataCache.empty().cluster();
+        Cluster fromMetadataEmpty = MetadataSnapshot.empty().cluster();
         Cluster fromClusterEmpty = Cluster.empty();
         assertEquals(fromMetadataEmpty, fromClusterEmpty);
     }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
index 24629b612b2..7f98e08e534 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
@@ -16,7 +16,7 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
-import java.util.Optional;
+import java.util.OptionalInt;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.KafkaException;
@@ -276,41 +276,55 @@ public class ProducerBatchTest {
         ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 
1), memoryRecordsBuilder, now);
 
         // Starting state for the batch, no attempt made to send it yet.
-        assertEquals(Optional.empty(), batch.currentLeaderEpoch());
+        assertEquals(OptionalInt.empty(), batch.currentLeaderEpoch());
         assertEquals(0, batch.attemptsWhenLeaderLastChanged()); // default 
value
-        batch.maybeUpdateLeaderEpoch(Optional.empty());
+        batch.maybeUpdateLeaderEpoch(OptionalInt.empty());
         assertFalse(batch.hasLeaderChangedForTheOngoingRetry());
 
         // 1st attempt[Not a retry] to send the batch.
         // Check leader isn't flagged as a new leader.
         int batchLeaderEpoch = 100;
-        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        batch.maybeUpdateLeaderEpoch(OptionalInt.of(batchLeaderEpoch));
         assertFalse(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader 
is assigned for 1st time");
-        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().getAsInt());
         assertEquals(0, batch.attemptsWhenLeaderLastChanged());
 
         // 2nd attempt[1st retry] to send the batch to a new leader.
         // Check leader change is detected.
         batchLeaderEpoch = 101;
         batch.reenqueued(0);
-        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        batch.maybeUpdateLeaderEpoch(OptionalInt.of(batchLeaderEpoch));
         assertTrue(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader 
has changed");
-        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().getAsInt());
         assertEquals(1, batch.attemptsWhenLeaderLastChanged());
 
         // 2nd attempt[1st retry] still ongoing, yet to be made.
         // Check same leaderEpoch(101) is still considered as a leader-change.
-        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        batch.maybeUpdateLeaderEpoch(OptionalInt.of(batchLeaderEpoch));
         assertTrue(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader 
has changed");
-        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().getAsInt());
         assertEquals(1, batch.attemptsWhenLeaderLastChanged());
 
         // 3rd attempt[2nd retry] to the same leader-epoch(101).
         // Check same leaderEpoch(101) as not detected as a leader-change.
         batch.reenqueued(0);
-        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        batch.maybeUpdateLeaderEpoch(OptionalInt.of(batchLeaderEpoch));
         assertFalse(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader 
has not changed");
-        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().getAsInt());
+        assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+
+        // Attempt made to update batch leader-epoch to an older 
leader-epoch(100).
+        // Check batch leader-epoch remains unchanged as 101.
+        batch.maybeUpdateLeaderEpoch(OptionalInt.of(batchLeaderEpoch - 1));
+        assertFalse(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader 
has not changed");
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().getAsInt());
+        assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+
+        // Attempt made to update batch leader-epoch to an unknown leader-epoch (OptionalInt.empty()).
+        // Check batch leader-epoch remains unchanged as 101.
+        batch.maybeUpdateLeaderEpoch(OptionalInt.empty());
+        assertFalse(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader 
has not changed");
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().getAsInt());
         assertEquals(1, batch.attemptsWhenLeaderLastChanged());
     }
 
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
index a046efe2cb2..9ed0e703809 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
@@ -17,10 +17,11 @@
 package org.apache.kafka.clients.producer.internals;
 
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.function.Function;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.CommonClientConfigs;
-import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.clients.MetadataSnapshot;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.Partitioner;
@@ -33,6 +34,7 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.CompressionRatioEstimator;
 import org.apache.kafka.common.record.CompressionType;
 import org.apache.kafka.common.record.DefaultRecord;
@@ -42,6 +44,8 @@ import org.apache.kafka.common.record.MemoryRecordsBuilder;
 import org.apache.kafka.common.record.MutableRecordBatch;
 import org.apache.kafka.common.record.Record;
 import org.apache.kafka.common.record.TimestampType;
+import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.ProducerIdAndEpoch;
@@ -89,29 +93,38 @@ public class RecordAccumulatorTest {
     private TopicPartition tp1 = new TopicPartition(topic, partition1);
     private TopicPartition tp2 = new TopicPartition(topic, partition2);
     private TopicPartition tp3 = new TopicPartition(topic, partition3);
-    private PartitionInfo part1 = new PartitionInfo(topic, partition1, node1, 
null, null);
-    private PartitionInfo part2 = new PartitionInfo(topic, partition2, node1, 
null, null);
-    private PartitionInfo part3 = new PartitionInfo(topic, partition3, node2, 
null, null);
+
+    private PartitionMetadata partMetadata1 = new 
PartitionMetadata(Errors.NONE, tp1, Optional.of(node1.id()), Optional.empty(), 
null, null, null);
+    private PartitionMetadata partMetadata2 = new 
PartitionMetadata(Errors.NONE, tp2, Optional.of(node1.id()), Optional.empty(), 
null, null, null);
+    private PartitionMetadata partMetadata3 = new 
PartitionMetadata(Errors.NONE, tp3, Optional.of(node2.id()), Optional.empty(), 
null, null, null);
+    private List<PartitionMetadata> partMetadatas = new 
ArrayList<>(Arrays.asList(partMetadata1, partMetadata2, partMetadata3));
+
+    private Map<Integer, Node> nodes = Arrays.asList(node1, 
node2).stream().collect(Collectors.toMap(Node::id, Function.identity()));
+    private MetadataSnapshot metadataCache = new MetadataSnapshot(null,
+        nodes,
+        partMetadatas,
+        Collections.emptySet(),
+        Collections.emptySet(),
+        Collections.emptySet(),
+        null,
+        Collections.emptyMap());
+
+    private Cluster cluster = metadataCache.cluster();
+
     private MockTime time = new MockTime();
     private byte[] key = "key".getBytes();
     private byte[] value = "value".getBytes();
     private int msgSize = DefaultRecord.sizeInBytes(0, 0, key.length, 
value.length, Record.EMPTY_HEADERS);
-    Metadata metadataMock;
-    private Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2, part3),
-        Collections.emptySet(), Collections.emptySet());
+
     private Metrics metrics = new Metrics(time);
     private final long maxBlockTimeMs = 1000;
     private final LogContext logContext = new LogContext();
 
-    @BeforeEach
-    public void setup() {
-        metadataMock = setupMetadata(cluster);
-    }
+    @BeforeEach void setup() {}
 
     @AfterEach
     public void teardown() {
         this.metrics.close();
-        Mockito.reset(metadataMock);
     }
 
     @Test
@@ -120,13 +133,31 @@ public class RecordAccumulatorTest {
         // add tp-4
         int partition4 = 3;
         TopicPartition tp4 = new TopicPartition(topic, partition4);
-        PartitionInfo part4 = new PartitionInfo(topic, partition4, node2, 
null, null);
+        PartitionMetadata partMetadata4 = new PartitionMetadata(Errors.NONE, 
tp4, Optional.of(node2.id()), Optional.empty(), null, null, null);
+        partMetadatas.add(partMetadata4);
+
+        // This test requires that partitions be drained in order for each node, i.e.
+        // node1 -> tp1, tp2, and node2 -> tp3, tp4.
+        // So set up the cluster with this order, and pass it to MetadataSnapshot to preserve this order.
+        PartitionInfo part1 = MetadataResponse.toPartitionInfo(partMetadata1, 
nodes);
+        PartitionInfo part2 = MetadataResponse.toPartitionInfo(partMetadata2, 
nodes);
+        PartitionInfo part3 = MetadataResponse.toPartitionInfo(partMetadata3, 
nodes);
+        PartitionInfo part4 = MetadataResponse.toPartitionInfo(partMetadata4, 
nodes);
+        Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2, part3, part4),
+            Collections.emptySet(), Collections.emptySet());
 
+        metadataCache = new MetadataSnapshot(null,
+            nodes,
+            partMetadatas,
+            Collections.emptySet(),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            null,
+            Collections.emptyMap(),
+            cluster);
         long batchSize = value.length + 
DefaultRecordBatch.RECORD_BATCH_OVERHEAD;
         RecordAccumulator accum = createTestRecordAccumulator((int) batchSize, 
Integer.MAX_VALUE, CompressionType.NONE, 10);
-        Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2, part3, part4),
-                Collections.emptySet(), Collections.emptySet());
-        metadataMock = setupMetadata(cluster);
+
 
         //  initial data
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
@@ -135,7 +166,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition4, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
 
         // drain batches from 2 nodes: node1 => tp1, node2 => tp3, because the 
max request size is full after the first batch drained
-        Map<Integer, List<ProducerBatch>> batches1 = accum.drain(metadataMock, 
new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches1 = 
accum.drain(metadataCache, new HashSet<>(Arrays.asList(node1, node2)), (int) 
batchSize, 0);
         verifyTopicPartitionInBatches(batches1, tp1, tp3);
 
         // add record for tp1, tp3
@@ -144,11 +175,11 @@ public class RecordAccumulatorTest {
 
         // drain batches from 2 nodes: node1 => tp2, node2 => tp4, because the 
max request size is full after the first batch drained
         // The drain index should start from next topic partition, that is, 
node1 => tp2, node2 => tp4
-        Map<Integer, List<ProducerBatch>> batches2 = accum.drain(metadataMock, 
new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches2 = 
accum.drain(metadataCache, new HashSet<>(Arrays.asList(node1, node2)), (int) 
batchSize, 0);
         verifyTopicPartitionInBatches(batches2, tp2, tp4);
 
         // make sure in next run, the drain index will start from the beginning
-        Map<Integer, List<ProducerBatch>> batches3 = accum.drain(metadataMock, 
new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches3 = 
accum.drain(metadataCache, new HashSet<>(Arrays.asList(node1, node2)), (int) 
batchSize, 0);
         verifyTopicPartitionInBatches(batches3, tp1, tp3);
 
         // add record for tp2, tp3, tp4 and mute the tp4
@@ -157,7 +188,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition4, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
         accum.mutePartition(tp4);
         // drain batches from 2 nodes: node1 => tp2, node2 => tp3 (because tp4 
is muted)
-        Map<Integer, List<ProducerBatch>> batches4 = accum.drain(metadataMock, 
new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches4 = 
accum.drain(metadataCache, new HashSet<>(Arrays.asList(node1, node2)), (int) 
batchSize, 0);
         verifyTopicPartitionInBatches(batches4, tp2, tp3);
 
         // add record for tp1, tp2, tp3, and unmute tp4
@@ -166,7 +197,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition3, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
         accum.unmutePartition(tp4);
         // set maxSize as a max value, so that the all partitions in 2 nodes 
should be drained: node1 => [tp1, tp2], node2 => [tp3, tp4]
-        Map<Integer, List<ProducerBatch>> batches5 = accum.drain(metadataMock, 
new HashSet<>(Arrays.asList(node1, node2)), Integer.MAX_VALUE, 0);
+        Map<Integer, List<ProducerBatch>> batches5 = 
accum.drain(metadataCache, new HashSet<>(Arrays.asList(node1, node2)), 
Integer.MAX_VALUE, 0);
         verifyTopicPartitionInBatches(batches5, tp1, tp2, tp3, tp4);
     }
 
@@ -197,25 +228,25 @@ public class RecordAccumulatorTest {
         int appends = expectedNumAppends(batchSize);
         for (int i = 0; i < appends; i++) {
             // append to the first batch
-            accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
+            accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
metadataCache.cluster());
             Deque<ProducerBatch> partitionBatches = accum.getDeque(tp1);
             assertEquals(1, partitionBatches.size());
 
             ProducerBatch batch = partitionBatches.peekFirst();
             assertTrue(batch.isWritable());
-            assertEquals(0, accum.ready(metadataMock, now).readyNodes.size(), 
"No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataCache, now).readyNodes.size(), 
"No partitions should be ready.");
         }
 
         // this append doesn't fit in the first batch, so a new batch is 
created and the first batch is closed
 
-        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
+        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), metadataCache.cluster());
         Deque<ProducerBatch> partitionBatches = accum.getDeque(tp1);
         assertEquals(2, partitionBatches.size());
         Iterator<ProducerBatch> partitionBatchesIterator = 
partitionBatches.iterator();
         assertTrue(partitionBatchesIterator.next().isWritable());
-        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        assertEquals(Collections.singleton(node1), accum.ready(metadataCache, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
 
-        List<ProducerBatch> batches = accum.drain(metadataMock, 
Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
+        List<ProducerBatch> batches = accum.drain(metadataCache, 
Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
         assertEquals(1, batches.size());
         ProducerBatch batch = batches.get(0);
 
@@ -243,8 +274,8 @@ public class RecordAccumulatorTest {
         byte[] value = new byte[2 * batchSize];
         RecordAccumulator accum = createTestRecordAccumulator(
                 batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 
1024, compressionType, 0);
-        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), metadataCache.cluster());
+        assertEquals(Collections.singleton(node1), accum.ready(metadataCache, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
 
         Deque<ProducerBatch> batches = accum.getDeque(tp1);
         assertEquals(1, batches.size());
@@ -281,8 +312,8 @@ public class RecordAccumulatorTest {
 
         RecordAccumulator accum = createTestRecordAccumulator(
                 batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 
1024, compressionType, 0);
-        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), metadataCache.cluster());
+        assertEquals(Collections.singleton(node1), accum.ready(metadataCache, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
 
         Deque<ProducerBatch> batches = accum.getDeque(tp1);
         assertEquals(1, batches.size());
@@ -306,10 +337,10 @@ public class RecordAccumulatorTest {
         RecordAccumulator accum = createTestRecordAccumulator(
                 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, 
CompressionType.NONE, lingerMs);
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        assertEquals(0, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready");
+        assertEquals(0, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready");
         time.sleep(10);
-        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
-        List<ProducerBatch> batches = accum.drain(metadataMock, 
Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
+        assertEquals(Collections.singleton(node1), accum.ready(metadataCache, 
time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        List<ProducerBatch> batches = accum.drain(metadataCache, 
Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
         assertEquals(1, batches.size());
         ProducerBatch batch = batches.get(0);
 
@@ -330,9 +361,9 @@ public class RecordAccumulatorTest {
             for (int i = 0; i < appends; i++)
                 accum.append(tp.topic(), tp.partition(), 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
         }
-        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, 
time.milliseconds()).readyNodes, "Partition's leader should be ready");
+        assertEquals(Collections.singleton(node1), accum.ready(metadataCache, 
time.milliseconds()).readyNodes, "Partition's leader should be ready");
 
-        List<ProducerBatch> batches = accum.drain(metadataMock, 
Collections.singleton(node1), 1024, 0).get(node1.id());
+        List<ProducerBatch> batches = accum.drain(metadataCache, 
Collections.singleton(node1), 1024, 0).get(node1.id());
         assertEquals(1, batches.size(), "But due to size bound only one 
partition should have been retrieved");
     }
 
@@ -361,8 +392,8 @@ public class RecordAccumulatorTest {
         int read = 0;
         long now = time.milliseconds();
         while (read < numThreads * msgs) {
-            Set<Node> nodes = accum.ready(metadataMock, now).readyNodes;
-            List<ProducerBatch> batches = accum.drain(metadataMock, nodes, 5 * 
1024, 0).get(node1.id());
+            Set<Node> nodes = accum.ready(metadataCache, now).readyNodes;
+            List<ProducerBatch> batches = accum.drain(metadataCache, nodes, 5 
* 1024, 0).get(node1.id());
             if (batches != null) {
                 for (ProducerBatch batch : batches) {
                     for (Record record : batch.records().records())
@@ -393,7 +424,7 @@ public class RecordAccumulatorTest {
         // Partition on node1 only
         for (int i = 0; i < appends; i++)
             accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No nodes should be ready.");
         assertEquals(lingerMs, result.nextReadyCheckDelayMs, "Next check time 
should be the linger time");
 
@@ -402,14 +433,14 @@ public class RecordAccumulatorTest {
         // Add partition on node2 only
         for (int i = 0; i < appends; i++)
             accum.append(topic, partition3, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
-        result = accum.ready(metadataMock, time.milliseconds());
+        result = accum.ready(metadataCache, time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No nodes should be ready.");
         assertEquals(lingerMs / 2, result.nextReadyCheckDelayMs, "Next check 
time should be defined by node1, half remaining linger time");
 
         // Add data for another partition on node1, enough to make data 
sendable immediately
         for (int i = 0; i < appends + 1; i++)
             accum.append(topic, partition2, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
-        result = accum.ready(metadataMock, time.milliseconds());
+        result = accum.ready(metadataCache, time.milliseconds());
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 
should be ready");
         // Note this can actually be < linger time because it may use delays 
from partitions that aren't sendable
         // but have leaders with other sendable data.
@@ -433,9 +464,9 @@ public class RecordAccumulatorTest {
 
         long now = time.milliseconds();
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
now + lingerMs + 1);
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
now + lingerMs + 1);
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 
should be ready");
-        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1);
+        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1);
         assertEquals(1, batches.size(), "Node1 should be the only ready 
node.");
         assertEquals(1, batches.get(0).size(), "Partition 0 should only have 
one batch drained.");
 
@@ -445,37 +476,37 @@ public class RecordAccumulatorTest {
 
         // Put message for partition 1 into accumulator
         accum.append(topic, partition2, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        result = accum.ready(metadataMock, now + lingerMs + 1);
+        result = accum.ready(metadataCache, now + lingerMs + 1);
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 
should be ready");
 
         // tp1 should backoff while tp2 should not
-        batches = accum.drain(metadataMock, result.readyNodes, 
Integer.MAX_VALUE, now + lingerMs + 1);
+        batches = accum.drain(metadataCache, result.readyNodes, 
Integer.MAX_VALUE, now + lingerMs + 1);
         assertEquals(1, batches.size(), "Node1 should be the only ready 
node.");
         assertEquals(1, batches.get(0).size(), "Node1 should only have one 
batch drained.");
         assertEquals(tp2, batches.get(0).get(0).topicPartition, "Node1 should 
only have one batch for partition 1.");
 
         // Partition 0 can be drained after retry backoff
         long upperBoundBackoffMs = (long) (retryBackoffMs * (1 + 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
-        result = accum.ready(metadataMock, now + upperBoundBackoffMs + 1);
+        result = accum.ready(metadataCache, now + upperBoundBackoffMs + 1);
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 
should be ready");
-        batches = accum.drain(metadataMock, result.readyNodes, 
Integer.MAX_VALUE, now + upperBoundBackoffMs + 1);
+        batches = accum.drain(metadataCache, result.readyNodes, 
Integer.MAX_VALUE, now + upperBoundBackoffMs + 1);
         assertEquals(1, batches.size(), "Node1 should be the only ready 
node.");
         assertEquals(1, batches.get(0).size(), "Node1 should only have one 
batch drained.");
         assertEquals(tp1, batches.get(0).get(0).topicPartition, "Node1 should 
only have one batch for partition 0.");
     }
 
-    private Map<Integer, List<ProducerBatch>> drainAndCheckBatchAmount(Cluster 
cluster, Node leader, RecordAccumulator accum, long now, int expected) {
-        metadataMock = setupMetadata(cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
now);
+    private Map<Integer, List<ProducerBatch>> drainAndCheckBatchAmount(
+        MetadataSnapshot metadataCache, Node leader, RecordAccumulator accum, 
long now, int expected) {
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
now);
         if (expected > 0) {
             assertEquals(Collections.singleton(leader), result.readyNodes, 
"Leader should be ready");
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now);
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache, result.readyNodes, Integer.MAX_VALUE, now);
             assertEquals(expected, batches.size(), "Leader should be the only 
ready node.");
             assertEquals(expected, batches.get(leader.id()).size(), "Partition 
should only have " + expected + " batch drained.");
             return batches;
         } else {
             assertEquals(0, result.readyNodes.size(), "Leader should not be 
ready");
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now);
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache, result.readyNodes, Integer.MAX_VALUE, now);
             assertEquals(0, batches.size(), "Leader should not be drained.");
             return null;
         }
@@ -501,7 +532,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
 
         // No backoff for initial attempt
-        Map<Integer, List<ProducerBatch>> batches = 
drainAndCheckBatchAmount(cluster, node1, accum, now + lingerMs + 1, 1);
+        Map<Integer, List<ProducerBatch>> batches = 
drainAndCheckBatchAmount(metadataCache, node1, accum, now + lingerMs + 1, 1);
         ProducerBatch batch = batches.get(0).get(0);
         long currentRetryBackoffMs = 0;
 
@@ -513,9 +544,9 @@ public class RecordAccumulatorTest {
             long upperBoundBackoffMs = (long) (retryBackoffMs * 
Math.pow(CommonClientConfigs.RETRY_BACKOFF_EXP_BASE, i) * (1 + 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
             currentRetryBackoffMs = upperBoundBackoffMs;
             // Should back off
-            drainAndCheckBatchAmount(cluster, node1, accum, initial + 
lowerBoundBackoffMs - 1, 0);
+            drainAndCheckBatchAmount(metadataCache, node1, accum, initial + 
lowerBoundBackoffMs - 1, 0);
             // Should not back off
-            drainAndCheckBatchAmount(cluster, node1, accum, initial + 
upperBoundBackoffMs + 1, 1);
+            drainAndCheckBatchAmount(metadataCache, node1, accum, initial + 
upperBoundBackoffMs + 1, 1);
         }
     }
 
@@ -529,14 +560,28 @@ public class RecordAccumulatorTest {
         int batchSize = 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD;
         String metricGrpName = "producer-metrics";
 
-        PartitionInfo part1 = new PartitionInfo(topic, partition1, node1, 
null, null);
-        PartitionInfo part1Change = new PartitionInfo(topic, partition1, 
node2, null, null);
-        PartitionInfo part2 = new PartitionInfo(topic, partition2, node1, 
null, null);
-        PartitionInfo part3 = new PartitionInfo(topic, partition3, node2, 
null, null);
-        Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2, part3),
-                Collections.emptySet(), Collections.emptySet());
-        Cluster clusterChange = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1Change, part2, part3),
-                Collections.emptySet(), Collections.emptySet());
+        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, 
tp1, Optional.of(node1.id()), Optional.empty(), null, null, null);
+        PartitionMetadata part1MetadataChange = new 
PartitionMetadata(Errors.NONE, tp1, Optional.of(node2.id()), Optional.empty(), 
null, null, null);
+        PartitionMetadata part2Metadata = new PartitionMetadata(Errors.NONE, 
tp2, Optional.of(node1.id()), Optional.empty(), null, null, null);
+        PartitionMetadata part3Metadata = new PartitionMetadata(Errors.NONE, 
tp3, Optional.of(node2.id()), Optional.empty(), null, null, null);
+
+        MetadataSnapshot metadataCache = new MetadataSnapshot(null,
+            nodes,
+            Arrays.asList(part1Metadata, part2Metadata, part3Metadata),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            null,
+            Collections.emptyMap());
+
+        MetadataSnapshot metadataCacheChange = new MetadataSnapshot(null,
+            nodes,
+            Arrays.asList(part1MetadataChange, part2Metadata, part3Metadata),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            Collections.emptySet(),
+            null,
+            Collections.emptyMap());
 
         final RecordAccumulator accum = new RecordAccumulator(logContext, 
batchSize,
                 CompressionType.NONE, lingerMs, retryBackoffMs, 
retryBackoffMaxMs,
@@ -548,7 +593,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
 
         // No backoff for initial attempt
-        Map<Integer, List<ProducerBatch>> batches = 
drainAndCheckBatchAmount(cluster, node1, accum, now + lingerMs + 1, 1);
+        Map<Integer, List<ProducerBatch>> batches = 
drainAndCheckBatchAmount(metadataCache, node1, accum, now + lingerMs + 1, 1);
         ProducerBatch batch = batches.get(0).get(0);
 
         long lowerBoundBackoffMs;
@@ -560,9 +605,9 @@ public class RecordAccumulatorTest {
         lowerBoundBackoffMs = (long) (retryBackoffMs * (1 - 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         upperBoundBackoffMs = (long) (retryBackoffMs * (1 + 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         // Should back off
-        drainAndCheckBatchAmount(cluster, node1, accum, initial + 
lowerBoundBackoffMs - 1, 0);
+        drainAndCheckBatchAmount(metadataCache, node1, accum, initial + 
lowerBoundBackoffMs - 1, 0);
         // Should not back off
-        drainAndCheckBatchAmount(cluster, node1, accum, initial + 
upperBoundBackoffMs + 1, 1);
+        drainAndCheckBatchAmount(metadataCache, node1, accum, initial + 
upperBoundBackoffMs + 1, 1);
 
         // Retry 2 - delay by retryBackoffMs * 2 +/- jitter
         now = time.milliseconds();
@@ -570,9 +615,9 @@ public class RecordAccumulatorTest {
         lowerBoundBackoffMs = (long) (retryBackoffMs * 
CommonClientConfigs.RETRY_BACKOFF_EXP_BASE * (1 - 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         upperBoundBackoffMs = (long) (retryBackoffMs * 
CommonClientConfigs.RETRY_BACKOFF_EXP_BASE * (1 + 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         // Should back off
-        drainAndCheckBatchAmount(cluster, node1, accum, initial + 
lowerBoundBackoffMs - 1, 0);
+        drainAndCheckBatchAmount(metadataCache, node1, accum, initial + 
lowerBoundBackoffMs - 1, 0);
         // Should not back off
-        drainAndCheckBatchAmount(cluster, node1, accum, initial + 
upperBoundBackoffMs + 1, 1);
+        drainAndCheckBatchAmount(metadataCache, node1, accum, initial + 
upperBoundBackoffMs + 1, 1);
 
         // Retry 3 - after a leader change, delay by retryBackoffMs * 2^2 +/- 
jitter (could optimise to do not delay at all)
         now = time.milliseconds();
@@ -580,9 +625,9 @@ public class RecordAccumulatorTest {
         lowerBoundBackoffMs = (long) (retryBackoffMs * 
Math.pow(CommonClientConfigs.RETRY_BACKOFF_EXP_BASE, 2) * (1 - 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         upperBoundBackoffMs = (long) (retryBackoffMs * 
Math.pow(CommonClientConfigs.RETRY_BACKOFF_EXP_BASE, 2) * (1 + 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         // Should back off
-        drainAndCheckBatchAmount(clusterChange, node2, accum, initial + 
lowerBoundBackoffMs - 1, 0);
+        drainAndCheckBatchAmount(metadataCacheChange, node2, accum, initial + 
lowerBoundBackoffMs - 1, 0);
         // Should not back off
-        drainAndCheckBatchAmount(clusterChange, node2, accum, initial + 
upperBoundBackoffMs + 1, 1);
+        drainAndCheckBatchAmount(metadataCacheChange, node2, accum, initial + 
upperBoundBackoffMs + 1, 1);
 
         // Retry 4 - delay by retryBackoffMs * 2^3 +/- jitter (capped to 
retryBackoffMaxMs)
         now = time.milliseconds();
@@ -590,9 +635,9 @@ public class RecordAccumulatorTest {
         lowerBoundBackoffMs = (long) (retryBackoffMs * 
Math.pow(CommonClientConfigs.RETRY_BACKOFF_EXP_BASE, 3) * (1 - 
CommonClientConfigs.RETRY_BACKOFF_JITTER));
         upperBoundBackoffMs = retryBackoffMaxMs;
         // Should back off
-        drainAndCheckBatchAmount(clusterChange, node2, accum, initial + 
lowerBoundBackoffMs - 1, 0);
+        drainAndCheckBatchAmount(metadataCacheChange, node2, accum, initial + 
lowerBoundBackoffMs - 1, 0);
         // Should not back off
-        drainAndCheckBatchAmount(clusterChange, node2, accum, initial + 
upperBoundBackoffMs + 1, 1);
+        drainAndCheckBatchAmount(metadataCacheChange, node2, accum, initial + 
upperBoundBackoffMs + 1, 1);
     }
 
     @Test
@@ -605,14 +650,14 @@ public class RecordAccumulatorTest {
             accum.append(topic, i % 3, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
             assertTrue(accum.hasIncomplete());
         }
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No nodes should be ready.");
 
         accum.beginFlush();
-        result = accum.ready(metadataMock, time.milliseconds());
+        result = accum.ready(metadataCache, time.milliseconds());
 
         // drain and deallocate all batches
-        Map<Integer, List<ProducerBatch>> results = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> results = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(accum.hasIncomplete());
 
         for (List<ProducerBatch> batches: results.values())
@@ -672,9 +717,9 @@ public class RecordAccumulatorTest {
         }
         for (int i = 0; i < numRecords; i++)
             accum.append(topic, i % 3, 0L, key, value, null, new 
TestCallback(), maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertFalse(result.readyNodes.isEmpty());
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(accum.hasUndrained());
         assertTrue(accum.hasIncomplete());
 
@@ -717,9 +762,9 @@ public class RecordAccumulatorTest {
         }
         for (int i = 0; i < numRecords; i++)
             accum.append(topic, i % 3, 0L, key, value, null, new 
TestCallback(), maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertFalse(result.readyNodes.isEmpty());
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE,
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE,
                 time.milliseconds());
         assertTrue(accum.hasUndrained());
         assertTrue(accum.hasIncomplete());
@@ -756,10 +801,10 @@ public class RecordAccumulatorTest {
             if (time.milliseconds() < System.currentTimeMillis())
                 time.setCurrentTimeMs(System.currentTimeMillis());
             accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
-            assertEquals(0, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partition should be ready.");
+            assertEquals(0, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partition should be ready.");
 
             time.sleep(lingerMs);
-            readyNodes = accum.ready(metadataMock, 
time.milliseconds()).readyNodes;
+            readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
             assertEquals(Collections.singleton(node1), readyNodes, "Our 
partition's leader should be ready");
 
             expiredBatches = accum.expiredBatches(time.milliseconds());
@@ -774,7 +819,7 @@ public class RecordAccumulatorTest {
             time.sleep(deliveryTimeoutMs - lingerMs);
             expiredBatches = accum.expiredBatches(time.milliseconds());
             assertEquals(1, expiredBatches.size(), "The batch may expire when 
the partition is muted");
-            assertEquals(0, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
         }
     }
 
@@ -805,11 +850,11 @@ public class RecordAccumulatorTest {
         // Test batches not in retry
         for (int i = 0; i < appends; i++) {
             accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
-            assertEquals(0, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
         }
         // Make the batches ready due to batch full
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, 0, false, time.milliseconds(), cluster);
-        Set<Node> readyNodes = accum.ready(metadataMock, 
time.milliseconds()).readyNodes;
+        Set<Node> readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
         assertEquals(Collections.singleton(node1), readyNodes, "Our 
partition's leader should be ready");
         // Advance the clock to expire the batch.
         time.sleep(deliveryTimeoutMs + 1);
@@ -820,7 +865,7 @@ public class RecordAccumulatorTest {
         accum.unmutePartition(tp1);
         expiredBatches = accum.expiredBatches(time.milliseconds());
         assertEquals(0, expiredBatches.size(), "All batches should have been 
expired earlier");
-        assertEquals(0, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+        assertEquals(0, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
 
         // Advance the clock to make the next batch ready due to linger.ms
         time.sleep(lingerMs);
@@ -834,15 +879,15 @@ public class RecordAccumulatorTest {
         accum.unmutePartition(tp1);
         expiredBatches = accum.expiredBatches(time.milliseconds());
         assertEquals(0, expiredBatches.size(), "All batches should have been 
expired");
-        assertEquals(0, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+        assertEquals(0, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
 
         // Test batches in retry.
         // Create a retried batch
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, 0, false, time.milliseconds(), cluster);
         time.sleep(lingerMs);
-        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
+        readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
         assertEquals(Collections.singleton(node1), readyNodes, "Our 
partition's leader should be ready");
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, 
readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataCache, 
readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(drained.get(node1.id()).size(), 1, "There should be only 
one batch.");
         time.sleep(1000L);
         accum.reenqueue(drained.get(node1.id()).get(0), time.milliseconds());
@@ -864,7 +909,7 @@ public class RecordAccumulatorTest {
         // Test that when being throttled muted batches are expired before the 
throttle time is over.
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, 0, false, time.milliseconds(), cluster);
         time.sleep(lingerMs);
-        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
+        readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
         assertEquals(Collections.singleton(node1), readyNodes, "Our 
partition's leader should be ready");
         // Advance the clock to expire the batch.
         time.sleep(requestTimeout + 1);
@@ -882,7 +927,7 @@ public class RecordAccumulatorTest {
         time.sleep(throttleTimeMs);
         expiredBatches = accum.expiredBatches(time.milliseconds());
         assertEquals(0, expiredBatches.size(), "All batches should have been 
expired earlier");
-        assertEquals(1, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+        assertEquals(1, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
     }
 
     @Test
@@ -896,28 +941,28 @@ public class RecordAccumulatorTest {
         int appends = expectedNumAppends(batchSize);
         for (int i = 0; i < appends; i++) {
             accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
-            assertEquals(0, accum.ready(metadataMock, now).readyNodes.size(), 
"No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataCache, now).readyNodes.size(), 
"No partitions should be ready.");
         }
         time.sleep(2000);
 
         // Test ready with muted partition
         accum.mutePartition(tp1);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No node should be ready");
 
         // Test ready without muted partition
         accum.unmutePartition(tp1);
-        result = accum.ready(metadataMock, time.milliseconds());
+        result = accum.ready(metadataCache, time.milliseconds());
         assertTrue(result.readyNodes.size() > 0, "The batch should be ready");
 
         // Test drain with muted partition
         accum.mutePartition(tp1);
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(0, drained.get(node1.id()).size(), "No batch should have 
been drained");
 
         // Test drain without muted partition.
         accum.unmutePartition(tp1);
-        drained = accum.drain(metadataMock, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
+        drained = accum.drain(metadataCache, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
         assertTrue(drained.get(node1.id()).size() > 0, "The batch should have 
been drained.");
     }
 
@@ -967,20 +1012,20 @@ public class RecordAccumulatorTest {
             false, time.milliseconds(), cluster);
         assertTrue(accumulator.hasUndrained());
 
-        RecordAccumulator.ReadyCheckResult firstResult = 
accumulator.ready(metadataMock, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult firstResult = 
accumulator.ready(metadataCache, time.milliseconds());
         assertEquals(0, firstResult.readyNodes.size());
-        Map<Integer, List<ProducerBatch>> firstDrained = 
accumulator.drain(metadataMock, firstResult.readyNodes,
+        Map<Integer, List<ProducerBatch>> firstDrained = 
accumulator.drain(metadataCache, firstResult.readyNodes,
             Integer.MAX_VALUE, time.milliseconds());
         assertEquals(0, firstDrained.size());
 
         // Once the transaction begins completion, then the batch should be 
drained immediately.
         Mockito.when(transactionManager.isCompleting()).thenReturn(true);
 
-        RecordAccumulator.ReadyCheckResult secondResult = 
accumulator.ready(metadataMock, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult secondResult = 
accumulator.ready(metadataCache, time.milliseconds());
         assertEquals(1, secondResult.readyNodes.size());
         Node readyNode = secondResult.readyNodes.iterator().next();
 
-        Map<Integer, List<ProducerBatch>> secondDrained = 
accumulator.drain(metadataMock, secondResult.readyNodes,
+        Map<Integer, List<ProducerBatch>> secondDrained = 
accumulator.drain(metadataCache, secondResult.readyNodes,
             Integer.MAX_VALUE, time.milliseconds());
         assertEquals(Collections.singleton(readyNode.id()), 
secondDrained.keySet());
         List<ProducerBatch> batches = secondDrained.get(readyNode.id());
@@ -1011,16 +1056,16 @@ public class RecordAccumulatorTest {
         // Re-enqueuing counts as a second attempt, so the delay with jitter 
is 100 * (1 + 0.2) + 1
         time.sleep(121L);
         // Drain the batch.
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertTrue(result.readyNodes.size() > 0, "The batch should be ready");
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(1, drained.size(), "Only node1 should be drained");
         assertEquals(1, drained.get(node1.id()).size(), "Only one batch should 
be drained");
         // Split and reenqueue the batch.
         accum.splitAndReenqueue(drained.get(node1.id()).get(0));
         time.sleep(101L);
 
-        drained = accum.drain(metadataMock, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
+        drained = accum.drain(metadataCache, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
         assertFalse(drained.isEmpty());
         assertFalse(drained.get(node1.id()).isEmpty());
         drained.get(node1.id()).get(0).complete(acked.get(), 100L);
@@ -1028,7 +1073,7 @@ public class RecordAccumulatorTest {
         assertTrue(future1.isDone());
         assertEquals(0, future1.get().offset());
 
-        drained = accum.drain(metadataMock, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
+        drained = accum.drain(metadataCache, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
         assertFalse(drained.isEmpty());
         assertFalse(drained.get(node1.id()).isEmpty());
         drained.get(node1.id()).get(0).complete(acked.get(), 100L);
@@ -1049,14 +1094,14 @@ public class RecordAccumulatorTest {
         int numSplitBatches = prepareSplitBatches(accum, seed, 100, 20);
         assertTrue(numSplitBatches > 0, "There should be some split batches");
         // Drain all the split batches.
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         for (int i = 0; i < numSplitBatches; i++) {
             Map<Integer, List<ProducerBatch>> drained =
-                accum.drain(metadataMock, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
+                accum.drain(metadataCache, result.readyNodes, 
Integer.MAX_VALUE, time.milliseconds());
             assertFalse(drained.isEmpty());
             assertFalse(drained.get(node1.id()).isEmpty());
         }
-        assertTrue(accum.ready(metadataMock, 
time.milliseconds()).readyNodes.isEmpty(), "All the batches should have been 
drained.");
+        assertTrue(accum.ready(metadataCache, 
time.milliseconds()).readyNodes.isEmpty(), "All the batches should have been 
drained.");
         assertEquals(bufferCapacity, accum.bufferPoolAvailableMemory(),
             "The split batches should be allocated off the accumulator");
     }
@@ -1103,16 +1148,16 @@ public class RecordAccumulatorTest {
             batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 
batchSize, CompressionType.NONE, lingerMs);
 
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, 
null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        Set<Node> readyNodes = accum.ready(metadataMock, 
time.milliseconds()).readyNodes;
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, 
readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Set<Node> readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataCache, 
readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(drained.isEmpty());
         //assertTrue(accum.soonToExpireInFlightBatches().isEmpty());
 
         // advanced clock and send one batch out but it should not be included 
in soon to expire inflight
         // batches because batch's expiry is quite far.
         time.sleep(lingerMs + 1);
-        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
-        drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, 
time.milliseconds());
+        readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
+        drained = accum.drain(metadataCache, readyNodes, Integer.MAX_VALUE, 
time.milliseconds());
         assertEquals(1, drained.size(), "A batch did not drain after linger");
         //assertTrue(accum.soonToExpireInFlightBatches().isEmpty());
 
@@ -1121,8 +1166,8 @@ public class RecordAccumulatorTest {
         time.sleep(lingerMs * 4);
 
         // Now drain and check that accumulator picked up the drained batch 
because its expiry is soon.
-        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
-        drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, 
time.milliseconds());
+        readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
+        drained = accum.drain(metadataCache, readyNodes, Integer.MAX_VALUE, 
time.milliseconds());
         assertEquals(1, drained.size(), "A batch did not drain after linger");
     }
 
@@ -1144,9 +1189,9 @@ public class RecordAccumulatorTest {
         for (Boolean mute : muteStates) {
             accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null, 0, false, time.milliseconds(), cluster);
             time.sleep(lingerMs);
-            readyNodes = accum.ready(metadataMock, 
time.milliseconds()).readyNodes;
+            readyNodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
             assertEquals(Collections.singleton(node1), readyNodes, "Our 
partition's leader should be ready");
-            Map<Integer, List<ProducerBatch>> drained = 
accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+            Map<Integer, List<ProducerBatch>> drained = 
accum.drain(metadataCache, readyNodes, Integer.MAX_VALUE, time.milliseconds());
             assertEquals(1, drained.get(node1.id()).size(), "There should be 
only one batch.");
             time.sleep(rtt);
             accum.reenqueue(drained.get(node1.id()).get(0), 
time.milliseconds());
@@ -1158,7 +1203,7 @@ public class RecordAccumulatorTest {
 
             // test expiration
             time.sleep(deliveryTimeoutMs - rtt);
-            accum.drain(metadataMock, Collections.singleton(node1), 
Integer.MAX_VALUE, time.milliseconds());
+            accum.drain(metadataCache, Collections.singleton(node1), 
Integer.MAX_VALUE, time.milliseconds());
             expiredBatches = accum.expiredBatches(time.milliseconds());
             assertEquals(mute ? 1 : 0, expiredBatches.size(), 
"RecordAccumulator has expired batches if the partition is not muted");
         }
@@ -1199,12 +1244,12 @@ public class RecordAccumulatorTest {
             // We only appended if we do not retry.
             if (!switchPartition) {
                 appends++;
-                assertEquals(0, accum.ready(metadataMock, 
now).readyNodes.size(), "No partitions should be ready.");
+                assertEquals(0, accum.ready(metadataCache, 
now).readyNodes.size(), "No partitions should be ready.");
             }
         }
 
         // Batch should be full.
-        assertEquals(1, accum.ready(metadataMock, 
time.milliseconds()).readyNodes.size());
+        assertEquals(1, accum.ready(metadataCache, 
time.milliseconds()).readyNodes.size());
         assertEquals(appends, expectedAppends);
         switchPartition = false;
 
@@ -1263,6 +1308,12 @@ public class RecordAccumulatorTest {
                 }
             };
 
+            PartitionInfo part1 = 
MetadataResponse.toPartitionInfo(partMetadata1, nodes);
+            PartitionInfo part2 = 
MetadataResponse.toPartitionInfo(partMetadata2, nodes);
+            PartitionInfo part3 = 
MetadataResponse.toPartitionInfo(partMetadata3, nodes);
+            Cluster cluster = new Cluster(null, asList(node1, node2), 
asList(part1, part2, part3),
+                Collections.emptySet(), Collections.emptySet());
+
             // Produce small record, we should switch to first partition.
             accum.append(topic, RecordMetadata.UNKNOWN_PARTITION, 0L, null, 
value, Record.EMPTY_HEADERS,
                 callbacks, maxBlockTimeMs, false, time.milliseconds(), 
cluster);
@@ -1332,7 +1383,7 @@ public class RecordAccumulatorTest {
             }
 
             // Let the accumulator generate the probability tables.
-            accum.ready(metadataMock, time.milliseconds());
+            accum.ready(metadataCache, time.milliseconds());
 
             // Set up callbacks so that we know what partition is chosen.
             final AtomicInteger partition = new 
AtomicInteger(RecordMetadata.UNKNOWN_PARTITION);
@@ -1376,7 +1427,7 @@ public class RecordAccumulatorTest {
             // Test that partitions residing on high-latency nodes don't get 
switched to.
             accum.updateNodeLatencyStats(0, time.milliseconds() - 200, true);
             accum.updateNodeLatencyStats(0, time.milliseconds(), false);
-            accum.ready(metadataMock, time.milliseconds());
+            accum.ready(metadataCache, time.milliseconds());
 
             // Do one append, because partition gets switched after append.
             accum.append(topic, RecordMetadata.UNKNOWN_PARTITION, 0L, null, 
largeValue, Record.EMPTY_HEADERS,
@@ -1413,9 +1464,9 @@ public class RecordAccumulatorTest {
             time.sleep(10);
 
             // We should have one batch ready.
-            Set<Node> nodes = accum.ready(metadataMock, 
time.milliseconds()).readyNodes;
+            Set<Node> nodes = accum.ready(metadataCache, 
time.milliseconds()).readyNodes;
             assertEquals(1, nodes.size(), "Should have 1 leader ready");
-            List<ProducerBatch> batches = accum.drain(metadataMock, nodes, 
Integer.MAX_VALUE, 0).entrySet().iterator().next().getValue();
+            List<ProducerBatch> batches = accum.drain(metadataCache, nodes, 
Integer.MAX_VALUE, 0).entrySet().iterator().next().getValue();
             assertEquals(1, batches.size(), "Should have 1 batch ready");
             int actualBatchSize = batches.get(0).records().sizeInBytes();
             assertTrue(actualBatchSize > batchSize / 2, "Batch must be greater 
than half batch.size");
@@ -1431,12 +1482,9 @@ public class RecordAccumulatorTest {
     @Test
     public void testReadyAndDrainWhenABatchIsBeingRetried() throws 
InterruptedException {
         int part1LeaderEpoch = 100;
-        // Create cluster metadata, partition1 being hosted by node1.
-        part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
-        cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1),
-            Collections.emptySet(), Collections.emptySet());
-        final int finalEpoch = part1LeaderEpoch;
-        metadataMock = setupMetadata(cluster, tp -> finalEpoch);
+        // Create cluster metadata, partition1 being hosted by node1
+        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, 
tp1, Optional.of(node1.id()),  Optional.of(part1LeaderEpoch), null, null, null);
+        MetadataSnapshot metadataCache = new MetadataSnapshot(null, nodes, 
Arrays.asList(part1Metadata), Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(), null, Collections.emptyMap());
 
         int batchSize = 10;
         int lingerMs = 10;
@@ -1457,14 +1505,14 @@ public class RecordAccumulatorTest {
         // 1st attempt(not a retry) to produce batchA, it should be ready & 
drained to be produced.
         {
             now += lingerMs + 1;
-            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataMock, now);
+            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataCache, now);
             assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
 
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache,
                 result.readyNodes, 999999 /* maxSize */, now);
             assertTrue(batches.containsKey(node1.id()) && 
batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
             ProducerBatch batch = batches.get(node1.id()).get(0);
-            assertEquals(Optional.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
+            assertEquals(OptionalInt.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
             assertEquals(0, batch.attemptsWhenLeaderLastChanged());
             // Re-enqueue batch for subsequent retries & test-cases
             accum.reenqueue(batch, now);
@@ -1473,11 +1521,11 @@ public class RecordAccumulatorTest {
         // In this retry of batchA, wait-time between retries is less than 
configured and no leader change, so should backoff.
         {
             now += 1;
-            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataMock, now);
+            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataCache, now);
             assertFalse(result.readyNodes.contains(node1), "Node1 is not 
ready");
 
             // Try to drain from node1, it should return no batches.
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache,
                 new HashSet<>(Arrays.asList(node1)), 999999 /* maxSize */, 
now);
             assertTrue(batches.containsKey(node1.id()) && 
batches.get(node1.id()).isEmpty(),
                 "No batches ready to be drained on Node1");
@@ -1488,19 +1536,16 @@ public class RecordAccumulatorTest {
             now += 1;
             part1LeaderEpoch++;
             // Create cluster metadata, with new leader epoch.
-            part1 = new PartitionInfo(topic, partition1, node1, null, null, 
null);
-            cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1),
-                Collections.emptySet(), Collections.emptySet());
-            final int finalPart1LeaderEpoch = part1LeaderEpoch;
-            metadataMock = setupMetadata(cluster, tp -> finalPart1LeaderEpoch);
-            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataMock, now);
+            part1Metadata = new PartitionMetadata(Errors.NONE, tp1, 
Optional.of(node1.id()),  Optional.of(part1LeaderEpoch), null, null, null);
+            metadataCache = new MetadataSnapshot(null, nodes, 
Arrays.asList(part1Metadata), Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(), null, Collections.emptyMap());
+            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataCache, now);
             assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
 
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache,
                 result.readyNodes, 999999 /* maxSize */, now);
             assertTrue(batches.containsKey(node1.id()) && 
batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
             ProducerBatch batch = batches.get(node1.id()).get(0);
-            assertEquals(Optional.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
+            assertEquals(OptionalInt.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
             assertEquals(1, batch.attemptsWhenLeaderLastChanged());
 
             // Re-enqueue batch for subsequent retries/test-cases.
@@ -1511,19 +1556,16 @@ public class RecordAccumulatorTest {
         {
             now += 2 * retryBackoffMaxMs;
             // Create cluster metadata, with new leader epoch.
-            part1 = new PartitionInfo(topic, partition1, node1, null, null, 
null);
-            cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1),
-                Collections.emptySet(), Collections.emptySet());
-            final int finalPart1LeaderEpoch = part1LeaderEpoch;
-            metadataMock = setupMetadata(cluster, tp -> finalPart1LeaderEpoch);
-            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataMock, now);
+            part1Metadata = new PartitionMetadata(Errors.NONE, tp1, 
Optional.of(node1.id()),  Optional.of(part1LeaderEpoch), null, null, null);
+            metadataCache = new MetadataSnapshot(null, nodes, 
Arrays.asList(part1Metadata), Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(), null, Collections.emptyMap());
+            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataCache, now);
             assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
 
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache,
                 result.readyNodes, 999999 /* maxSize */, now);
             assertTrue(batches.containsKey(node1.id()) && 
batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
             ProducerBatch batch = batches.get(node1.id()).get(0);
-            assertEquals(Optional.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
+            assertEquals(OptionalInt.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
             assertEquals(1, batch.attemptsWhenLeaderLastChanged());
 
             // Re-enqueue batch for subsequent retries/test-cases.
@@ -1535,19 +1577,16 @@ public class RecordAccumulatorTest {
             now += 2 * retryBackoffMaxMs;
             part1LeaderEpoch++;
             // Create cluster metadata, with new leader epoch.
-            part1 = new PartitionInfo(topic, partition1, node1, null, null, 
null);
-            cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1),
-                Collections.emptySet(), Collections.emptySet());
-            final int finalPart1LeaderEpoch = part1LeaderEpoch;
-            metadataMock = setupMetadata(cluster, tp -> finalPart1LeaderEpoch);
-            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataMock, now);
+            part1Metadata = new PartitionMetadata(Errors.NONE, tp1, 
Optional.of(node1.id()),  Optional.of(part1LeaderEpoch), null, null, null);
+            metadataCache = new MetadataSnapshot(null, nodes, 
Arrays.asList(part1Metadata), Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(), null, Collections.emptyMap());
+            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataCache, now);
             assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
 
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache,
                 result.readyNodes, 999999 /* maxSize */, now);
             assertTrue(batches.containsKey(node1.id()) && 
batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
             ProducerBatch batch = batches.get(node1.id()).get(0);
-            assertEquals(Optional.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
+            assertEquals(OptionalInt.of(part1LeaderEpoch), 
batch.currentLeaderEpoch());
             assertEquals(3, batch.attemptsWhenLeaderLastChanged());
 
             // Re-enqueue batch for subsequent retries/test-cases.
@@ -1564,97 +1603,15 @@ public class RecordAccumulatorTest {
             CompressionType.NONE, lingerMs);
 
         // Create cluster metadata, node2 doesn't host any partitions.
-        part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
-        cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1),
-            Collections.emptySet(), Collections.emptySet());
-        metadataMock = Mockito.mock(Metadata.class);
-        Mockito.when(metadataMock.fetch()).thenReturn(cluster);
-        Mockito.when(metadataMock.currentLeader(tp1)).thenReturn(
-            new Metadata.LeaderAndEpoch(Optional.of(node1),
-                Optional.of(999 /* dummy value */)));
+        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, 
tp1, Optional.of(node1.id()), Optional.empty(), null, null, null);
+        MetadataSnapshot metadataCache = new MetadataSnapshot(null, nodes, 
Arrays.asList(part1Metadata), Collections.emptySet(), Collections.emptySet(), 
Collections.emptySet(), null, Collections.emptyMap());
 
         // Drain for node2, it should return 0 batches,
-        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataCache,
             new HashSet<>(Arrays.asList(node2)), 999999 /* maxSize */, 
time.milliseconds());
         assertTrue(batches.get(node2.id()).isEmpty());
     }
 
-    @Test
-    public void testDrainOnANodeWhenItCeasesToBeALeader() throws 
InterruptedException {
-        int batchSize = 10;
-        int lingerMs = 10;
-        long totalSize = 10 * 1024;
-        RecordAccumulator accum = createTestRecordAccumulator(batchSize, 
totalSize,
-            CompressionType.NONE, lingerMs);
-
-        // While node1 is being drained, leader changes from node1 -> node2 
for a partition.
-        {
-            // Create cluster metadata, partition1&2 being hosted by node1&2 
resp.
-            part1 = new PartitionInfo(topic, partition1, node1, null, null, 
null);
-            part2 = new PartitionInfo(topic, partition2, node2, null, null, 
null);
-            cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2),
-                Collections.emptySet(), Collections.emptySet());
-            metadataMock = Mockito.mock(Metadata.class);
-            Mockito.when(metadataMock.fetch()).thenReturn(cluster);
-            // But metadata has a newer leader for partition1 i.e node2.
-            Mockito.when(metadataMock.currentLeader(tp1)).thenReturn(
-                new Metadata.LeaderAndEpoch(Optional.of(node2),
-                    Optional.of(999 /* dummy value */)));
-            Mockito.when(metadataMock.currentLeader(tp2)).thenReturn(
-                new Metadata.LeaderAndEpoch(Optional.of(node2),
-                    Optional.of(999 /* dummy value */)));
-
-            // Create 1 batch each for partition1 & partition2.
-            long now = time.milliseconds();
-            accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null,
-                maxBlockTimeMs, false, now, cluster);
-            accum.append(topic, partition2, 0L, key, value, 
Record.EMPTY_HEADERS, null,
-                maxBlockTimeMs, false, now, cluster);
-
-            // Drain for node1, it should return 0 batches, as partition1's 
leader in metadata changed.
-            // Drain for node2, it should return 1 batch, for partition2.
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
-                new HashSet<>(Arrays.asList(node1, node2)), 999999 /* maxSize 
*/, now);
-            assertTrue(batches.get(node1.id()).isEmpty());
-            assertEquals(1, batches.get(node2.id()).size());
-        }
-
-        // Cleanup un-drained batches to have an empty accum before next test.
-        accum.abortUndrainedBatches(new RuntimeException());
-
-        // While node1 is being drained, leader changes from node1 -> 
"no-leader" for partition.
-        {
-            // Create cluster metadata, partition1&2 being hosted by node1&2 
resp.
-            part1 = new PartitionInfo(topic, partition1, node1, null, null, 
null);
-            part2 = new PartitionInfo(topic, partition2, node2, null, null, 
null);
-            cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2),
-                Collections.emptySet(), Collections.emptySet());
-            metadataMock = Mockito.mock(Metadata.class);
-            Mockito.when(metadataMock.fetch()).thenReturn(cluster);
-            // But metadata no longer has a leader for partition1.
-            Mockito.when(metadataMock.currentLeader(tp1)).thenReturn(
-                new Metadata.LeaderAndEpoch(Optional.empty(),
-                    Optional.of(999 /* dummy value */)));
-            Mockito.when(metadataMock.currentLeader(tp2)).thenReturn(
-                new Metadata.LeaderAndEpoch(Optional.of(node2),
-                    Optional.of(999 /* dummy value */)));
-
-            // Create 1 batch each for partition1 & partition2.
-            long now = time.milliseconds();
-            accum.append(topic, partition1, 0L, key, value, 
Record.EMPTY_HEADERS, null,
-                maxBlockTimeMs, false, now, cluster);
-            accum.append(topic, partition2, 0L, key, value, 
Record.EMPTY_HEADERS, null,
-                maxBlockTimeMs, false, now, cluster);
-
-            // Drain for node1, it should return 0 batches, as partition1's 
leader in metadata changed.
-            // Drain for node2, it should return 1 batch, for partition2.
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock,
-                new HashSet<>(Arrays.asList(node1, node2)), 999999 /* maxSize 
*/, now);
-            assertTrue(batches.get(node1.id()).isEmpty());
-            assertEquals(1, batches.get(node2.id()).size());
-        }
-    }
-
     private int prepareSplitBatches(RecordAccumulator accum, long seed, int 
recordSize, int numRecords)
         throws InterruptedException {
         Random random = new Random();
@@ -1667,9 +1624,9 @@ public class RecordAccumulatorTest {
             accum.append(topic, partition1, 0L, null, 
bytesWithPoorCompression(random, recordSize), Record.EMPTY_HEADERS, null, 0, 
false, time.milliseconds(), cluster);
         }
 
-        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, 
time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataCache, 
time.milliseconds());
         assertFalse(result.readyNodes.isEmpty());
-        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataCache, 
result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(1, batches.size());
         assertEquals(1, batches.values().iterator().next().size());
         ProducerBatch batch = batches.values().iterator().next().get(0);
@@ -1685,8 +1642,8 @@ public class RecordAccumulatorTest {
         boolean batchDrained;
         do {
             batchDrained = false;
-            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataMock, time.milliseconds());
-            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, 
time.milliseconds());
+            RecordAccumulator.ReadyCheckResult result = 
accum.ready(metadataCache, time.milliseconds());
+            Map<Integer, List<ProducerBatch>> batches = 
accum.drain(metadataCache, result.readyNodes, Integer.MAX_VALUE, 
time.milliseconds());
             for (List<ProducerBatch> batchList : batches.values()) {
                 for (ProducerBatch batch : batchList) {
                     batchDrained = true;
@@ -1805,27 +1762,4 @@ public class RecordAccumulatorTest {
             txnManager,
             new BufferPool(totalSize, batchSize, metrics, time, 
metricGrpName));
     }
-
-    /**
-     * Setup a mocked metadata object.
-     */
-    private Metadata setupMetadata(Cluster cluster) {
-        return setupMetadata(cluster, tp -> 999 /* dummy epoch */);
-    }
-
-    /**
-     * Setup a mocked metadata object.
-     */
-    private Metadata setupMetadata(Cluster cluster, final 
Function<TopicPartition, Integer> epochSupplier) {
-        Metadata metadataMock = Mockito.mock(Metadata.class);
-        Mockito.when(metadataMock.fetch()).thenReturn(cluster);
-        for (String topic: cluster.topics()) {
-            for (PartitionInfo partInfo: cluster.partitionsForTopic(topic)) {
-                TopicPartition tp = new TopicPartition(partInfo.topic(), 
partInfo.partition());
-                Integer partLeaderEpoch = epochSupplier.apply(tp);
-                Mockito.when(metadataMock.currentLeader(tp)).thenReturn(new 
Metadata.LeaderAndEpoch(Optional.of(partInfo.leader()), 
Optional.of(partLeaderEpoch)));
-            }
-        }
-        return metadataMock;
-    }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index a524f2f78dc..9af3aff8bbf 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -20,6 +20,7 @@ import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.clients.MetadataSnapshot;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.NodeApiVersions;
@@ -471,7 +472,7 @@ public class SenderTest {
         final byte[] key = "key".getBytes();
         final byte[] value = "value".getBytes();
         final long maxBlockTimeMs = 1000;
-        Cluster cluster = TestUtils.singletonCluster();
+        MetadataSnapshot metadataCache = TestUtils.metadataSnapshotWith(1);
         RecordAccumulator.AppendCallbacks callbacks = new 
RecordAccumulator.AppendCallbacks() {
             @Override
             public void setPartition(int partition) {
@@ -483,7 +484,7 @@ public class SenderTest {
                     expiryCallbackCount.incrementAndGet();
                     try {
                         accumulator.append(tp1.topic(), tp1.partition(), 0L, 
key, value,
-                            Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, 
time.milliseconds(), cluster);
+                            Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, 
time.milliseconds(), metadataCache.cluster());
                     } catch (InterruptedException e) {
                         throw new RuntimeException("Unexpected interruption", 
e);
                     }
@@ -494,14 +495,14 @@ public class SenderTest {
 
         final long nowMs = time.milliseconds();
         for (int i = 0; i < messagesPerBatch; i++)
-            accumulator.append(tp1.topic(), tp1.partition(), 0L, key, value, 
null, callbacks, maxBlockTimeMs, false, nowMs, cluster);
+            accumulator.append(tp1.topic(), tp1.partition(), 0L, key, value, 
null, callbacks, maxBlockTimeMs, false, nowMs, metadataCache.cluster());
 
         // Advance the clock to expire the first batch.
         time.sleep(10000);
 
         Node clusterNode = metadata.fetch().nodes().get(0);
         Map<Integer, List<ProducerBatch>> drainedBatches =
-            accumulator.drain(metadata, Collections.singleton(clusterNode), 
Integer.MAX_VALUE, time.milliseconds());
+            accumulator.drain(metadataCache, 
Collections.singleton(clusterNode), Integer.MAX_VALUE, time.milliseconds());
         sender.addToInflightBatches(drainedBatches);
 
         // Disconnect the target node for the pending produce request. This 
will ensure that sender will try to
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index b615d1c5356..a65beeebb81 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -18,6 +18,7 @@ package org.apache.kafka.clients.producer.internals;
 
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.clients.MetadataSnapshot;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.consumer.CommitFailedException;
@@ -69,6 +70,7 @@ import 
org.apache.kafka.common.requests.FindCoordinatorResponse;
 import org.apache.kafka.common.requests.InitProducerIdRequest;
 import org.apache.kafka.common.requests.InitProducerIdResponse;
 import org.apache.kafka.common.requests.JoinGroupRequest;
+import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata;
 import org.apache.kafka.common.requests.ProduceRequest;
 import org.apache.kafka.common.requests.ProduceResponse;
 import org.apache.kafka.common.requests.RequestTestUtils;
@@ -2471,16 +2473,16 @@ public class TransactionManagerTest {
 
         Node node1 = new Node(0, "localhost", 1111);
         Node node2 = new Node(1, "localhost", 1112);
-        PartitionInfo part1 = new PartitionInfo(topic, 0, node1, null, null);
-        PartitionInfo part2 = new PartitionInfo(topic, 1, node2, null, null);
-
-        Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), 
Arrays.asList(part1, part2),
-                Collections.emptySet(), Collections.emptySet());
-        Metadata metadataMock = setupMetadata(cluster);
+        Map<Integer, Node> nodesById = new HashMap<>();
+        nodesById.put(node1.id(), node1);
+        nodesById.put(node2.id(), node2);
+        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, 
tp0, Optional.of(node1.id()), Optional.empty(), null, null, null);
+        PartitionMetadata part2Metadata = new PartitionMetadata(Errors.NONE, 
tp1, Optional.of(node2.id()), Optional.empty(), null, null, null);
+        MetadataSnapshot metadataCache = new MetadataSnapshot(null, nodesById, 
Arrays.asList(part1Metadata, part2Metadata), Collections.emptySet(), 
Collections.emptySet(), Collections.emptySet(), null, Collections.emptyMap());
         Set<Node> nodes = new HashSet<>();
         nodes.add(node1);
         nodes.add(node2);
-        Map<Integer, List<ProducerBatch>> drainedBatches = 
accumulator.drain(metadataMock, nodes, Integer.MAX_VALUE,
+        Map<Integer, List<ProducerBatch>> drainedBatches = 
accumulator.drain(metadataCache, nodes, Integer.MAX_VALUE,
                 time.milliseconds());
 
         // We shouldn't drain batches which haven't been added to the 
transaction yet.
@@ -2506,12 +2508,10 @@ public class TransactionManagerTest {
 
         // Try to drain a message destined for tp1, it should get drained.
         Node node1 = new Node(1, "localhost", 1112);
-        PartitionInfo part1 = new PartitionInfo(topic, 1, node1, null, null);
-        Cluster cluster = new Cluster(null, Collections.singletonList(node1), 
Collections.singletonList(part1),
-                Collections.emptySet(), Collections.emptySet());
-        Metadata metadataMock = setupMetadata(cluster);
+        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, 
tp1, Optional.of(node1.id()), Optional.empty(), null, null, null);
+        MetadataSnapshot metadataCache = new MetadataSnapshot(null, 
Collections.singletonMap(node1.id(), node1), Arrays.asList(part1Metadata), 
Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), null, 
Collections.emptyMap());
         appendToAccumulator(tp1);
-        Map<Integer, List<ProducerBatch>> drainedBatches = 
accumulator.drain(metadataMock, Collections.singleton(node1),
+        Map<Integer, List<ProducerBatch>> drainedBatches = 
accumulator.drain(metadataCache, Collections.singleton(node1),
                 Integer.MAX_VALUE,
                 time.milliseconds());
 
@@ -2529,15 +2529,12 @@ public class TransactionManagerTest {
         // Don't execute 
transactionManager.maybeAddPartitionToTransaction(tp0). This should result in 
an error on drain.
         appendToAccumulator(tp0);
         Node node1 = new Node(0, "localhost", 1111);
-        PartitionInfo part1 = new PartitionInfo(topic, 0, node1, null, null);
-
-        Cluster cluster = new Cluster(null, Collections.singletonList(node1), 
Collections.singletonList(part1),
-                Collections.emptySet(), Collections.emptySet());
-        Metadata metadataMock = setupMetadata(cluster);
+        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, 
tp0, Optional.of(node1.id()), Optional.empty(), null, null, null);
+        MetadataSnapshot metadataCache = new MetadataSnapshot(null, 
Collections.singletonMap(node1.id(), node1), Arrays.asList(part1Metadata), 
Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), null, 
Collections.emptyMap());
 
         Set<Node> nodes = new HashSet<>();
         nodes.add(node1);
-        Map<Integer, List<ProducerBatch>> drainedBatches = 
accumulator.drain(metadataMock, nodes, Integer.MAX_VALUE,
+        Map<Integer, List<ProducerBatch>> drainedBatches = 
accumulator.drain(metadataCache, nodes, Integer.MAX_VALUE,
                 time.milliseconds());
 
         // We shouldn't drain batches which haven't been added to the 
transaction yet.
diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java 
b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
index c8a6db6f6ca..db86558f7d0 100644
--- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.test;
 
+import org.apache.kafka.clients.MetadataSnapshot;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.Cluster;
@@ -29,10 +30,12 @@ import 
org.apache.kafka.common.message.ApiVersionsResponseData;
 import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.Send;
 import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.RecordVersion;
 import org.apache.kafka.common.record.UnalignedRecords;
 import org.apache.kafka.common.requests.ApiVersionsResponse;
 import org.apache.kafka.common.requests.ByteBufferChannel;
+import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata;
 import org.apache.kafka.common.requests.RequestHeader;
 import org.apache.kafka.common.utils.Exit;
 import org.apache.kafka.common.utils.Utils;
@@ -123,6 +126,42 @@ public class TestUtils {
         return clusterWith(nodes, Collections.singletonMap(topic, partitions));
     }
 
+    /**
+     * Test utility function to get MetadataSnapshot with configured nodes and partitions.
+     * Partition {@code i} of every topic is led by node {@code i % nodes}.
+     * @param nodes number of nodes in the cluster; must be positive if any topic has partitions
+     * @param topicPartitionCounts map of topic -> # of partitions
+     * @return a MetadataSnapshot with number of nodes, partitions as per the input.
+     */
+    public static MetadataSnapshot metadataSnapshotWith(final int nodes, final Map<String, Integer> topicPartitionCounts) {
+        final Node[] ns = new Node[nodes];
+        Map<Integer, Node> nodesById = new HashMap<>();
+        for (int i = 0; i < nodes; i++) {
+            ns[i] = new Node(i, "localhost", 1969);
+            nodesById.put(ns[i].id(), ns[i]);
+        }
+        final List<PartitionMetadata> partsMetadatas = new ArrayList<>();
+        for (final Map.Entry<String, Integer> topicPartition : topicPartitionCounts.entrySet()) {
+            final String topic = topicPartition.getKey();
+            final int partitions = topicPartition.getValue();
+            for (int i = 0; i < partitions; i++) {
+                TopicPartition tp = new TopicPartition(topic, i); // index must be i, not the partition count
+                Node node = ns[i % ns.length];
+                partsMetadatas.add(new PartitionMetadata(Errors.NONE, tp, Optional.of(node.id()), Optional.empty(), null, null, null));
+            }
+        }
+        return new MetadataSnapshot("kafka-cluster", nodesById, partsMetadatas, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), null, Collections.emptyMap());
+    }
+
+    /**
+     * Test utility function to get MetadataSnapshot of a cluster with the configured number of nodes and 0 partitions.
+     * @param nodes number of nodes in the cluster.
+     * @return a MetadataSnapshot of a cluster with the number of nodes in the input and no topics.
+     */
+    public static MetadataSnapshot metadataSnapshotWith(int nodes) {
+        return metadataSnapshotWith(nodes, Collections.emptyMap());
+    }
+
     /**
      * Generate an array of random bytes
      *

Reply via email to