Repository: kafka
Updated Branches:
  refs/heads/0.10.2 6493e5d8d -> d8f9e18aa


MINOR: Refactor partition lag metric for cleaner encapsulation

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Jiangjie Qin <becket....@gmail.com>, Guozhang Wang <wangg...@gmail.com>, Ismael Juma <ism...@juma.me.uk>

Closes #2416 from hachikuji/refactor-partition-lag-cleanup

(cherry picked from commit 135488352cbe646334a274418d7014a91e92057f)
Signed-off-by: Ismael Juma <ism...@juma.me.uk>


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/d8f9e18a
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/d8f9e18a
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/d8f9e18a

Branch: refs/heads/0.10.2
Commit: d8f9e18aa0f16b97e71719ca5983f8ff39134a4a
Parents: 6493e5d
Author: Jason Gustafson <ja...@confluent.io>
Authored: Thu Jan 26 02:53:59 2017 +0000
Committer: Ismael Juma <ism...@juma.me.uk>
Committed: Thu Jan 26 02:54:21 2017 +0000

----------------------------------------------------------------------
 .../consumer/ConsumerRebalanceListener.java     |   4 +-
 .../kafka/clients/consumer/KafkaConsumer.java   |   2 +-
 .../kafka/clients/consumer/MockConsumer.java    |   3 +-
 .../consumer/internals/ConsumerCoordinator.java |   2 +-
 .../clients/consumer/internals/Fetcher.java     | 107 ++++++++++---------
 .../consumer/internals/SubscriptionState.java   |  62 ++++++-----
 .../clients/consumer/KafkaConsumerTest.java     |   2 +-
 .../internals/ConsumerCoordinatorTest.java      |   2 +-
 .../clients/consumer/internals/FetcherTest.java |  16 ++-
 .../internals/SubscriptionStateTest.java        |  44 ++++++--
 .../kafka/api/BaseConsumerTest.scala            | 101 +----------------
 .../kafka/api/PlaintextConsumerTest.scala       | 102 +++++++++++++++++-
 12 files changed, 250 insertions(+), 197 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java
index 938d22b..a4265ab 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java
@@ -86,7 +86,7 @@ public interface ConsumerRebalanceListener {
      *
      * @param partitions The list of partitions that were assigned to the 
consumer on the last rebalance
      */
-    public void onPartitionsRevoked(Collection<TopicPartition> partitions);
+    void onPartitionsRevoked(Collection<TopicPartition> partitions);
 
     /**
      * A callback method the user can implement to provide handling of 
customized offsets on completion of a successful
@@ -100,5 +100,5 @@ public interface ConsumerRebalanceListener {
      * @param partitions The list of partitions that are now assigned to the 
consumer (may include partitions previously
      *            assigned to the consumer)
      */
-    public void onPartitionsAssigned(Collection<TopicPartition> partitions);
+    void onPartitionsAssigned(Collection<TopicPartition> partitions);
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java 
b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index be212db..2936f0f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -663,7 +663,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             this.client = new ConsumerNetworkClient(netClient, metadata, time, 
retryBackoffMs,
                     config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG));
             OffsetResetStrategy offsetResetStrategy = 
OffsetResetStrategy.valueOf(config.getString(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).toUpperCase(Locale.ROOT));
-            this.subscriptions = new SubscriptionState(offsetResetStrategy, 
metrics);
+            this.subscriptions = new SubscriptionState(offsetResetStrategy);
             List<PartitionAssignor> assignors = config.getConfiguredInstances(
                     ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
                     PartitionAssignor.class);

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java 
b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
index 95e3830..a88f432 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
@@ -34,7 +34,6 @@ import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.regex.Pattern;
-import org.apache.kafka.common.metrics.Metrics;
 
 
 /**
@@ -59,7 +58,7 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
     private AtomicBoolean wakeup;
 
     public MockConsumer(OffsetResetStrategy offsetResetStrategy) {
-        this.subscriptions = new SubscriptionState(offsetResetStrategy, new 
Metrics());
+        this.subscriptions = new SubscriptionState(offsetResetStrategy);
         this.partitions = new HashMap<>();
         this.records = new HashMap<>();
         this.paused = new HashSet<>();

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index 4c54a8f..03b767a 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -155,7 +155,7 @@ public final class ConsumerCoordinator extends 
AbstractCoordinator {
         final Set<String> topicsToSubscribe = new HashSet<>();
 
         for (String topic : cluster.topics())
-            if (subscriptions.getSubscribedPattern().matcher(topic).matches() 
&&
+            if (subscriptions.subscribedPattern().matcher(topic).matches() &&
                     !(excludeInternalTopics && 
cluster.internalTopics().contains(topic)))
                 topicsToSubscribe.add(topic);
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index cd7f307..6bd454e 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -72,7 +72,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 /**
  * This class manage the fetching process with the brokers.
  */
-public class Fetcher<K, V> {
+public class Fetcher<K, V> implements SubscriptionState.Listener {
 
     private static final Logger log = LoggerFactory.getLogger(Fetcher.class);
 
@@ -124,6 +124,8 @@ public class Fetcher<K, V> {
         this.completedFetches = new ConcurrentLinkedQueue<>();
         this.sensors = new FetchManagerMetrics(metrics, metricGrpPrefix);
         this.retryBackoffMs = retryBackoffMs;
+
+        subscriptions.addListener(this);
     }
 
     /**
@@ -453,10 +455,9 @@ public class Fetcher<K, V> {
                 if (completedFetch == null)
                     break;
 
-                nextInLineRecords = parseFetchedData(completedFetch);
+                nextInLineRecords = parseCompletedFetch(completedFetch);
             } else {
                 TopicPartition partition = nextInLineRecords.partition;
-
                 List<ConsumerRecord<K, V>> records = 
drainRecords(nextInLineRecords, recordsRemaining);
                 if (!records.isEmpty()) {
                     List<ConsumerRecord<K, V>> currentRecords = 
drained.get(partition);
@@ -480,9 +481,6 @@ public class Fetcher<K, V> {
     }
 
     private List<ConsumerRecord<K, V>> drainRecords(PartitionRecords<K, V> 
partitionRecords, int maxRecords) {
-        if (partitionRecords.isDrained())
-            return Collections.emptyList();
-
         if (!subscriptions.isAssigned(partitionRecords.partition)) {
             // this can happen when a rebalance happened before fetched 
records are returned to the consumer's poll call
             log.debug("Not returning fetched records for partition {} since it 
is no longer assigned", partitionRecords.partition);
@@ -493,19 +491,19 @@ public class Fetcher<K, V> {
                 // this can happen when a partition is paused before fetched 
records are returned to the consumer's poll call
                 log.debug("Not returning fetched records for assigned 
partition {} since it is no longer fetchable", partitionRecords.partition);
             } else if (partitionRecords.fetchOffset == position) {
-                // we are ensured to have at least one record since we already 
checked for emptiness
                 List<ConsumerRecord<K, V>> partRecords = 
partitionRecords.drainRecords(maxRecords);
-                long nextOffset = partRecords.get(partRecords.size() - 
1).offset() + 1;
+                if (!partRecords.isEmpty()) {
+                    long nextOffset = partRecords.get(partRecords.size() - 
1).offset() + 1;
+                    log.trace("Returning fetched records at offset {} for 
assigned partition {} and update " +
+                            "position to {}", position, 
partitionRecords.partition, nextOffset);
 
-                log.trace("Returning fetched records at offset {} for assigned 
partition {} and update " +
-                        "position to {}", position, 
partitionRecords.partition, nextOffset);
+                    subscriptions.position(partitionRecords.partition, 
nextOffset);
+                }
 
-                subscriptions.position(partitionRecords.partition, nextOffset);
                 Long partitionLag = 
subscriptions.partitionLag(partitionRecords.partition);
-                if (partitionLag != null) {
-                    this.sensors.recordsFetchLag.record(partitionLag);
-                    
this.sensors.recordPartitionFetchLag(partitionRecords.partition, partitionLag);
-                }
+                if (partitionLag != null)
+                    
this.sensors.recordPartitionLag(partitionRecords.partition, partitionLag);
+
                 return partRecords;
             } else {
                 // these records aren't next in line based on the last 
consumed position, ignore them
@@ -731,7 +729,7 @@ public class Fetcher<K, V> {
     /**
      * The callback for fetch completion
      */
-    private PartitionRecords<K, V> parseFetchedData(CompletedFetch 
completedFetch) {
+    private PartitionRecords<K, V> parseCompletedFetch(CompletedFetch 
completedFetch) {
         TopicPartition tp = completedFetch.partition;
         FetchResponse.PartitionData partition = completedFetch.partitionData;
         long fetchOffset = completedFetch.fetchedOffset;
@@ -766,26 +764,13 @@ public class Fetcher<K, V> {
 
                 recordsCount = parsed.size();
 
-                if (!parsed.isEmpty()) {
-                    log.trace("Adding fetched record for partition {} with 
offset {} to buffered record list", tp, position);
-                    parsedRecords = new PartitionRecords<>(fetchOffset, tp, 
parsed);
-                }
+                log.trace("Adding fetched record for partition {} with offset 
{} to buffered record list", tp, position);
+                parsedRecords = new PartitionRecords<>(fetchOffset, tp, 
parsed);
 
                 if (partition.highWatermark >= 0) {
                     log.trace("Received {} records in fetch response for 
partition {} with offset {}", parsed.size(), tp, position);
-                    Long partitionLag = subscriptions.partitionLag(tp);
                     subscriptions.updateHighWatermark(tp, 
partition.highWatermark);
-                    // If the partition lag is null, that means this is the 
first fetch response for this partition.
-                    // We update the lag here to create the lag metric. This 
is to handle the case that there is no
-                    // message consumed by the end user from this partition. 
If there are messages returned from the
-                    // partition, the lag will be updated when those messages 
are consumed by the end user.
-                    if (partitionLag == null) {
-                        partitionLag = subscriptions.partitionLag(tp);
-                        this.sensors.recordsFetchLag.record(partitionLag);
-                        this.sensors.recordPartitionFetchLag(tp, partitionLag);
-                    }
                 }
-
             } else if (error == Errors.NOT_LEADER_FOR_PARTITION) {
                 log.debug("Error in fetch for partition {}: {}", tp, 
error.exceptionName());
                 this.metadata.requestUpdate();
@@ -860,20 +845,25 @@ public class Fetcher<K, V> {
         }
     }
 
+    @Override
+    public void onAssignment(Set<TopicPartition> assignment) {
+        sensors.updatePartitionLagSensors(assignment);
+    }
+
     private static class PartitionRecords<K, V> {
         private long fetchOffset;
         private TopicPartition partition;
         private List<ConsumerRecord<K, V>> records;
         private int position = 0;
 
-        public PartitionRecords(long fetchOffset, TopicPartition partition, 
List<ConsumerRecord<K, V>> records) {
+        private PartitionRecords(long fetchOffset, TopicPartition partition, 
List<ConsumerRecord<K, V>> records) {
             this.fetchOffset = fetchOffset;
             this.partition = partition;
             this.records = records;
         }
 
         private boolean isDrained() {
-            return records == null || position >= records.size();
+            return records == null;
         }
 
         private void drain() {
@@ -881,8 +871,10 @@ public class Fetcher<K, V> {
         }
 
         private List<ConsumerRecord<K, V>> drainRecords(int n) {
-            if (isDrained())
+            if (isDrained() || position >= records.size()) {
+                drain();
                 return Collections.emptyList();
+            }
 
             // using a sublist avoids a potentially expensive list copy 
(depending on the size of the records
             // and the maximum we can return from poll). The cost is that we 
cannot mutate the returned sublist.
@@ -903,10 +895,10 @@ public class Fetcher<K, V> {
         private final FetchResponse.PartitionData partitionData;
         private final FetchResponseMetricAggregator metricAggregator;
 
-        public CompletedFetch(TopicPartition partition,
-                              long fetchedOffset,
-                              FetchResponse.PartitionData partitionData,
-                              FetchResponseMetricAggregator metricAggregator) {
+        private CompletedFetch(TopicPartition partition,
+                               long fetchedOffset,
+                               FetchResponse.PartitionData partitionData,
+                               FetchResponseMetricAggregator metricAggregator) 
{
             this.partition = partition;
             this.fetchedOffset = fetchedOffset;
             this.partitionData = partitionData;
@@ -974,16 +966,17 @@ public class Fetcher<K, V> {
     }
 
     private static class FetchManagerMetrics {
-        public final Metrics metrics;
-        public final String metricGrpName;
+        private final Metrics metrics;
+        private final String metricGrpName;
+        private final Sensor bytesFetched;
+        private final Sensor recordsFetched;
+        private final Sensor fetchLatency;
+        private final Sensor recordsFetchLag;
+        private final Sensor fetchThrottleTimeSensor;
 
-        public final Sensor bytesFetched;
-        public final Sensor recordsFetched;
-        public final Sensor fetchLatency;
-        public final Sensor recordsFetchLag;
-        public final Sensor fetchThrottleTimeSensor;
+        private Set<TopicPartition> assignedPartitions;
 
-        public FetchManagerMetrics(Metrics metrics, String metricGrpPrefix) {
+        private FetchManagerMetrics(Metrics metrics, String metricGrpPrefix) {
             this.metrics = metrics;
             this.metricGrpName = metricGrpPrefix + "-fetch-manager-metrics";
 
@@ -1032,7 +1025,7 @@ public class Fetcher<K, V> {
                                                          "The maximum throttle 
time in ms"), new Max());
         }
 
-        public void recordTopicFetchMetrics(String topic, int bytes, int 
records) {
+        private void recordTopicFetchMetrics(String topic, int bytes, int 
records) {
             // record bytes fetched
             String name = "topic." + topic + ".bytes-fetched";
             Sensor bytesFetched = this.metrics.getSensor(name);
@@ -1075,8 +1068,20 @@ public class Fetcher<K, V> {
             recordsFetched.record(records);
         }
 
-        public void recordPartitionFetchLag(TopicPartition tp, long lag) {
-            String name = tp + ".records-lag";
+        private void updatePartitionLagSensors(Set<TopicPartition> 
assignedPartitions) {
+            if (this.assignedPartitions != null) {
+                for (TopicPartition tp : this.assignedPartitions) {
+                    if (!assignedPartitions.contains(tp))
+                        metrics.removeSensor(partitionLagMetricName(tp));
+                }
+            }
+            this.assignedPartitions = assignedPartitions;
+        }
+
+        private void recordPartitionLag(TopicPartition tp, long lag) {
+            this.recordsFetchLag.record(lag);
+
+            String name = partitionLagMetricName(tp);
             Sensor recordsLag = this.metrics.getSensor(name);
             if (recordsLag == null) {
                 recordsLag = this.metrics.sensor(name);
@@ -1091,6 +1096,10 @@ public class Fetcher<K, V> {
             }
             recordsLag.record(lag);
         }
+
+        private static String partitionLagMetricName(TopicPartition tp) {
+            return tp + ".records-lag";
+        }
     }
 
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index 1a2a7ee..25995fb 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -17,7 +17,6 @@ import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.internals.PartitionStates;
-import org.apache.kafka.common.metrics.Metrics;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -78,19 +77,19 @@ public class SubscriptionState {
     /* Default offset reset strategy */
     private final OffsetResetStrategy defaultResetStrategy;
 
-    /* Listener to be invoked when assignment changes */
+    /* User-provided listener to be invoked when assignment changes */
     private ConsumerRebalanceListener listener;
 
-    private final Metrics metrics;
+    /* Listeners provide a hook for internal state cleanup (e.g. metrics) on 
assignment changes */
+    private List<Listener> listeners = new ArrayList<>();
 
-    public SubscriptionState(OffsetResetStrategy defaultResetStrategy, Metrics 
metrics) {
+    public SubscriptionState(OffsetResetStrategy defaultResetStrategy) {
         this.defaultResetStrategy = defaultResetStrategy;
         this.subscription = Collections.emptySet();
         this.assignment = new PartitionStates<>();
         this.groupSubscription = new HashSet<>();
         this.needsFetchCommittedOffsets = true; // initialize to true for the 
consumers to fetch offset upon starting up
         this.subscribedPattern = null;
-        this.metrics = metrics;
         this.subscriptionType = SubscriptionType.NONE;
     }
 
@@ -160,7 +159,8 @@ public class SubscriptionState {
         setSubscriptionType(SubscriptionType.USER_ASSIGNED);
 
         if (!this.assignment.partitionSet().equals(partitions)) {
-            removeAllLagSensors(partitions);
+            fireOnAssignment(partitions);
+
             Map<TopicPartition, TopicPartitionState> partitionToState = new 
HashMap<>();
             for (TopicPartition partition : partitions) {
                 TopicPartitionState state = assignment.stateValue(partition);
@@ -180,8 +180,9 @@ public class SubscriptionState {
     public void assignFromSubscribed(Collection<TopicPartition> assignments) {
         if (!this.partitionsAutoAssigned())
             throw new IllegalArgumentException("Attempt to dynamically assign 
partitions while manual assignment in use");
-        Set<TopicPartition> newAssignment = new HashSet<>(assignments);
-        removeAllLagSensors(newAssignment);
+
+        Map<TopicPartition, TopicPartitionState> assignedPartitionStates = 
partitionToStateMap(assignments);
+        fireOnAssignment(assignedPartitionStates.keySet());
 
         if (this.subscribedPattern != null) {
             for (TopicPartition tp : assignments) {
@@ -195,24 +196,10 @@ public class SubscriptionState {
         }
 
         // after rebalancing, we always reinitialize the assignment value
-        this.assignment.set(partitionToStateMap(assignments));
+        this.assignment.set(assignedPartitionStates);
         this.needsFetchCommittedOffsets = true;
     }
 
-    private void removeAllLagSensors(Set<TopicPartition> preservedPartitions) {
-        for (TopicPartition tp : assignment.partitionSet()) {
-            if (!preservedPartitions.contains(tp))
-                metrics.removeSensor(tp + ".records-lag");
-        }
-    }
-
-    private Map<TopicPartition, TopicPartitionState> 
partitionToStateMap(Collection<TopicPartition> assignments) {
-        Map<TopicPartition, TopicPartitionState> map = new 
HashMap<>(assignments.size());
-        for (TopicPartition tp : assignments)
-            map.put(tp, new TopicPartitionState());
-        return map;
-    }
-
     public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) 
{
         if (listener == null)
             throw new IllegalArgumentException("RebalanceListener cannot be 
null");
@@ -236,9 +223,10 @@ public class SubscriptionState {
         this.assignment.clear();
         this.subscribedPattern = null;
         this.subscriptionType = SubscriptionType.NONE;
+        fireOnAssignment(Collections.<TopicPartition>emptySet());
     }
 
-    public Pattern getSubscribedPattern() {
+    public Pattern subscribedPattern() {
         return this.subscribedPattern;
     }
 
@@ -416,6 +404,22 @@ public class SubscriptionState {
         return listener;
     }
 
+    public void addListener(Listener listener) {
+        listeners.add(listener);
+    }
+
+    public void fireOnAssignment(Set<TopicPartition> assignment) {
+        for (Listener listener : listeners)
+            listener.onAssignment(assignment);
+    }
+
+    private static Map<TopicPartition, TopicPartitionState> 
partitionToStateMap(Collection<TopicPartition> assignments) {
+        Map<TopicPartition, TopicPartitionState> map = new 
HashMap<>(assignments.size());
+        for (TopicPartition tp : assignments)
+            map.put(tp, new TopicPartitionState());
+        return map;
+    }
+
     private static class TopicPartitionState {
         private Long position; // last consumed position
         private Long highWatermark; // the high watermark from last fetch
@@ -473,4 +477,14 @@ public class SubscriptionState {
 
     }
 
+    public interface Listener {
+        /**
+         * Fired after a new assignment is received (after a group rebalance 
or when the user manually changes the
+         * assignment).
+         *
+         * @param assignment The topic partitions assigned to the consumer
+         */
+        void onAssignment(Set<TopicPartition> assignment);
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
----------------------------------------------------------------------
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index ac88ce9..8346e93 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -1491,7 +1491,7 @@ public class KafkaConsumerTest {
         ConsumerInterceptors<String, String> interceptors = null;
 
         Metrics metrics = new Metrics();
-        SubscriptionState subscriptions = new 
SubscriptionState(autoResetStrategy, metrics);
+        SubscriptionState subscriptions = new 
SubscriptionState(autoResetStrategy);
         ConsumerNetworkClient consumerClient = new 
ConsumerNetworkClient(client, metadata, time, retryBackoffMs, requestTimeoutMs);
         ConsumerCoordinator consumerCoordinator = new ConsumerCoordinator(
                 consumerClient,

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
----------------------------------------------------------------------
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index 3c4dd2d..e13d49f 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -119,7 +119,7 @@ public class ConsumerCoordinatorTest {
     @Before
     public void setup() {
         this.time = new MockTime();
-        this.subscriptions = new 
SubscriptionState(OffsetResetStrategy.EARLIEST, metrics);
+        this.subscriptions = new 
SubscriptionState(OffsetResetStrategy.EARLIEST);
         this.metadata = new Metadata(0, Long.MAX_VALUE);
         this.metadata.update(cluster, time.milliseconds());
         this.client = new MockClient(time, metadata);

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
----------------------------------------------------------------------
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index bdd56c3..210af6d 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -91,8 +91,8 @@ public class FetcherTest {
     private Cluster cluster = TestUtils.singletonCluster(topicName, 1);
     private Node node = cluster.nodes().get(0);
     private Metrics metrics = new Metrics(time);
-    private SubscriptionState subscriptions = new 
SubscriptionState(OffsetResetStrategy.EARLIEST, metrics);
-    private SubscriptionState subscriptionsNoAutoReset = new 
SubscriptionState(OffsetResetStrategy.NONE, metrics);
+    private SubscriptionState subscriptions = new 
SubscriptionState(OffsetResetStrategy.EARLIEST);
+    private SubscriptionState subscriptionsNoAutoReset = new 
SubscriptionState(OffsetResetStrategy.NONE);
     private static final double EPSILON = 0.0001;
     private ConsumerNetworkClient consumerClient = new 
ConsumerNetworkClient(client, metadata, time, 100, 1000);
 
@@ -704,8 +704,11 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp));
         subscriptions.seek(tp, 0);
 
+        MetricName maxLagMetric = metrics.metricName("records-lag-max", 
metricGroup, "");
+        MetricName partitionLagMetric = metrics.metricName(tp + 
".records-lag", metricGroup, "");
+
         Map<MetricName, KafkaMetric> allMetrics = metrics.metrics();
-        KafkaMetric recordsFetchLagMax = 
allMetrics.get(metrics.metricName("records-lag-max", metricGroup, ""));
+        KafkaMetric recordsFetchLagMax = allMetrics.get(maxLagMetric);
 
         // recordsFetchLagMax should be initialized to negative infinity
         assertEquals(Double.NEGATIVE_INFINITY, recordsFetchLagMax.value(), 
EPSILON);
@@ -714,12 +717,19 @@ public class FetcherTest {
         fetchRecords(MemoryRecords.EMPTY, Errors.NONE.code(), 100L, 0);
         assertEquals(100, recordsFetchLagMax.value(), EPSILON);
 
+        KafkaMetric partitionLag = allMetrics.get(partitionLagMetric);
+        assertEquals(100, partitionLag.value(), EPSILON);
+
         // recordsFetchLagMax should be hw - offset of the last message after 
receiving a non-empty FetchResponse
         MemoryRecordsBuilder builder = 
MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, 
TimestampType.CREATE_TIME);
         for (int v = 0; v < 3; v++)
             builder.appendWithOffset((long) v, Record.NO_TIMESTAMP, 
"key".getBytes(), String.format("value-%d", v).getBytes());
         fetchRecords(builder.build(), Errors.NONE.code(), 200L, 0);
         assertEquals(197, recordsFetchLagMax.value(), EPSILON);
+
+        // verify de-registration of partition lag
+        subscriptions.unsubscribe();
+        assertFalse(allMetrics.containsKey(partitionLagMetric));
     }
 
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> 
fetchRecords(MemoryRecords records, short error, long hw, int throttleTime) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
----------------------------------------------------------------------
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
index 55bf2a3..61a55e2 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
@@ -20,13 +20,15 @@ import 
org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.utils.Utils;
 import org.junit.Test;
 
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Pattern;
 
 import static java.util.Collections.singleton;
@@ -36,7 +38,7 @@ import static org.junit.Assert.assertTrue;
 
 public class SubscriptionStateTest {
 
-    private final SubscriptionState state = new 
SubscriptionState(OffsetResetStrategy.EARLIEST, new Metrics());
+    private final SubscriptionState state = new 
SubscriptionState(OffsetResetStrategy.EARLIEST);
     private final String topic = "test";
     private final String topic1 = "test1";
     private final TopicPartition tp0 = new TopicPartition(topic, 0);
@@ -76,7 +78,7 @@ public class SubscriptionStateTest {
         // assigned partitions should remain unchanged
         assertTrue(state.assignedPartitions().isEmpty());
 
-        state.assignFromSubscribed(Collections.singletonList(t1p0));
+        state.assignFromSubscribed(singleton(t1p0));
         // assigned partitions should immediately change
         assertEquals(singleton(t1p0), state.assignedPartitions());
 
@@ -99,7 +101,7 @@ public class SubscriptionStateTest {
         // assigned partitions should remain unchanged
         assertTrue(state.assignedPartitions().isEmpty());
 
-        state.assignFromSubscribed(Collections.singletonList(tp1));
+        state.assignFromSubscribed(singleton(tp1));
         // assigned partitions should immediately change
         assertEquals(singleton(tp1), state.assignedPartitions());
         assertEquals(singleton(topic), state.subscription());
@@ -128,6 +130,28 @@ public class SubscriptionStateTest {
     }
 
     @Test
+    public void verifyAssignmentListener() {
+        final AtomicReference<Set<TopicPartition>> assignmentRef = new 
AtomicReference<>();
+        state.addListener(new SubscriptionState.Listener() {
+            @Override
+            public void onAssignment(Set<TopicPartition> assignment) {
+                assignmentRef.set(assignment);
+            }
+        });
+        Set<TopicPartition> userAssignment = Utils.mkSet(tp0, tp1);
+        state.assignFromUser(userAssignment);
+        assertEquals(userAssignment, assignmentRef.get());
+
+        state.unsubscribe();
+        assertEquals(Collections.emptySet(), assignmentRef.get());
+
+        Set<TopicPartition> autoAssignment = Utils.mkSet(t1p0);
+        state.subscribe(singleton(topic1), rebalanceListener);
+        state.assignFromSubscribed(autoAssignment);
+        assertEquals(autoAssignment, assignmentRef.get());
+    }
+
+    @Test
     public void partitionReset() {
         state.assignFromUser(singleton(tp0));
         state.seek(tp0, 5);
@@ -149,11 +173,11 @@ public class SubscriptionStateTest {
         assertEquals(1, state.subscription().size());
         assertTrue(state.assignedPartitions().isEmpty());
         assertTrue(state.partitionsAutoAssigned());
-        state.assignFromSubscribed(Collections.singletonList(tp0));
+        state.assignFromSubscribed(singleton(tp0));
         state.seek(tp0, 1);
         state.committed(tp0, new OffsetAndMetadata(1));
         assertAllPositions(tp0, 1L);
-        state.assignFromSubscribed(Collections.singletonList(tp1));
+        state.assignFromSubscribed(singleton(tp1));
         assertTrue(state.isAssigned(tp1));
         assertFalse(state.isAssigned(tp0));
         assertFalse(state.isFetchable(tp1));
@@ -183,7 +207,7 @@ public class SubscriptionStateTest {
     @Test(expected = IllegalStateException.class)
     public void invalidPositionUpdate() {
         state.subscribe(singleton(topic), rebalanceListener);
-        state.assignFromSubscribed(Collections.singletonList(tp0));
+        state.assignFromSubscribed(singleton(tp0));
         state.position(tp0, 0);
     }
 
@@ -233,9 +257,7 @@ public class SubscriptionStateTest {
     public void patternSubscription() {
         state.subscribe(Pattern.compile(".*"), rebalanceListener);
         state.subscribeFromPattern(new HashSet<>(Arrays.asList(topic, 
topic1)));
-
-        assertEquals(
-                "Expected subscribed topics count is incorrect", 2, 
state.subscription().size());
+        assertEquals("Expected subscribed topics count is incorrect", 2, 
state.subscription().size());
     }
 
     @Test
@@ -258,7 +280,7 @@ public class SubscriptionStateTest {
     public void unsubscription() {
         state.subscribe(Pattern.compile(".*"), rebalanceListener);
         state.subscribeFromPattern(new HashSet<>(Arrays.asList(topic, 
topic1)));
-        state.assignFromSubscribed(Collections.singletonList(tp1));
+        state.assignFromSubscribed(singleton(tp1));
         assertEquals(singleton(tp1), state.assignedPartitions());
 
         state.unsubscribe();

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala 
b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
index 4bc1678..802bab8 100644
--- a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
@@ -13,13 +13,12 @@
 package kafka.api
 
 import java.util
-import java.util.Collections
 
 import org.apache.kafka.clients.consumer._
 import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
 import org.apache.kafka.common.record.TimestampType
 import org.apache.kafka.common.serialization.ByteArrayDeserializer
-import org.apache.kafka.common.{MetricName, PartitionInfo, TopicPartition}
+import org.apache.kafka.common.{PartitionInfo, TopicPartition}
 import kafka.utils.{Logging, ShutdownableThread, TestUtils}
 import kafka.common.Topic
 import kafka.server.KafkaConfig
@@ -118,104 +117,6 @@ abstract class BaseConsumerTest extends 
IntegrationTestHarness with Logging {
     assertEquals(1, listener.callsToRevoked)
   }
 
-  @Test
-  def testPerPartitionLagMetricsCleanUpWithSubscribe() {
-    val numMessages = 1000
-    val topic2 = "topic2"
-    TestUtils.createTopic(this.zkUtils, topic2, 2, serverCount, this.servers)
-    // send some messages.
-    sendRecords(numMessages, tp)
-    // Test subscribe
-    // Create a consumer and consumer some messages.
-    consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithSubscribe")
-    consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithSubscribe")
-    val consumer = new KafkaConsumer(this.consumerConfig, new 
ByteArrayDeserializer(), new ByteArrayDeserializer())
-    try {
-      val listener0 = new TestConsumerReassignmentListener
-      consumer.subscribe(List(topic, topic2).asJava, listener0)
-      var records: ConsumerRecords[Array[Byte], Array[Byte]] = 
ConsumerRecords.empty()
-      TestUtils.waitUntilTrue(() => {
-          records = consumer.poll(100)
-          !records.records(tp).isEmpty
-        }, "Consumer did not consume any message before timeout.")
-      assertEquals("should be assigned once", 1, listener0.callsToAssigned)
-      // Verify the metric exist.
-      val tags = Collections.singletonMap("client-id", 
"testPerPartitionLagMetricsCleanUpWithSubscribe")
-      val fetchLag0 = consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags))
-      assertNotNull(fetchLag0)
-      val expectedLag = numMessages - records.count
-      assertEquals(s"The lag should be $expectedLag", expectedLag, 
fetchLag0.value, epsilon)
-
-      // Remove topic from subscription
-      consumer.subscribe(List(topic2).asJava, listener0)
-      TestUtils.waitUntilTrue(() => {
-        consumer.poll(100)
-        listener0.callsToAssigned >= 2
-      }, "Expected rebalance did not occur.")
-      // Verify the metric has gone
-      assertNull(consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags)))
-      assertNull(consumer.metrics.get(new MetricName(tp2 + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags)))
-    } finally {
-      consumer.close()
-    }
-  }
-
-  @Test
-  def testPerPartitionLagMetricsCleanUpWithAssign() {
-    val numMessages = 1000
-    // Test assign
-    // send some messages.
-    sendRecords(numMessages, tp)
-    sendRecords(numMessages, tp2)
-    consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithAssign")
-    consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithAssign")
-    val consumer = new KafkaConsumer(this.consumerConfig, new 
ByteArrayDeserializer(), new ByteArrayDeserializer())
-    try {
-      consumer.assign(List(tp).asJava)
-      var records: ConsumerRecords[Array[Byte], Array[Byte]] = 
ConsumerRecords.empty()
-      TestUtils.waitUntilTrue(() => {
-          records = consumer.poll(100)
-          !records.records(tp).isEmpty
-        }, "Consumer did not consume any message before timeout.")
-      // Verify the metric exist.
-      val tags = Collections.singletonMap("client-id", 
"testPerPartitionLagMetricsCleanUpWithAssign")
-      val fetchLag = consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags))
-      assertNotNull(fetchLag)
-      val expectedLag = numMessages - records.count
-      assertEquals(s"The lag should be $expectedLag", expectedLag, 
fetchLag.value, epsilon)
-
-      consumer.assign(List(tp2).asJava)
-      TestUtils.waitUntilTrue(() => !consumer.poll(100).isEmpty, "Consumer did 
not consume any message before timeout.")
-      assertNull(consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags)))
-    } finally {
-      consumer.close()
-    }
-  }
-
-  @Test
-  def testPerPartitionLagWithMaxPollRecords() {
-    val numMessages = 1000
-    val maxPollRecords = 10
-    sendRecords(numMessages, tp)
-    consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"testPerPartitionLagWithMaxPollRecords")
-    consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, 
"testPerPartitionLagWithMaxPollRecords")
-    consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 
maxPollRecords.toString)
-    val consumer = new KafkaConsumer(this.consumerConfig, new 
ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumer.assign(List(tp).asJava)
-    try {
-      var records: ConsumerRecords[Array[Byte], Array[Byte]] = 
ConsumerRecords.empty()
-      TestUtils.waitUntilTrue(() => {
-          records = consumer.poll(100)
-          !records.isEmpty
-        }, "Consumer did not consume any message before timeout.")
-      val tags = Collections.singletonMap("client-id", 
"testPerPartitionLagWithMaxPollRecords")
-      val lag = consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags))
-      assertEquals(s"The lag should be ${numMessages - records.count}", 
numMessages - records.count, lag.value, epsilon)
-    } finally {
-      consumer.close()
-    }
-  }
-
   protected class TestConsumerReassignmentListener extends 
ConsumerRebalanceListener {
     var callsToAssigned = 0
     var callsToRevoked = 0

http://git-wip-us.apache.org/repos/asf/kafka/blob/d8f9e18a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala 
b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index 282d67c..4fa1462 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -12,7 +12,6 @@
   */
 package kafka.api
 
-
 import java.util
 import java.util.regex.Pattern
 import java.util.{Collections, Locale, Properties}
@@ -22,7 +21,7 @@ import kafka.server.KafkaConfig
 import kafka.utils.TestUtils
 import org.apache.kafka.clients.consumer._
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, 
ProducerRecord}
-import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.{MetricName, TopicPartition}
 import org.apache.kafka.common.errors.InvalidTopicException
 import org.apache.kafka.common.record.{CompressionType, TimestampType}
 import org.apache.kafka.common.serialization.{ByteArrayDeserializer, 
ByteArraySerializer, StringDeserializer, StringSerializer}
@@ -1230,6 +1229,105 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     assertEquals(500, consumer0.committed(tp2).offset)
   }
 
+
+  @Test
+  def testPerPartitionLagMetricsCleanUpWithSubscribe() {
+    val numMessages = 1000
+    val topic2 = "topic2"
+    TestUtils.createTopic(this.zkUtils, topic2, 2, serverCount, this.servers)
+    // send some messages.
+    sendRecords(numMessages, tp)
+    // Test subscribe
+    // Create a consumer and consumer some messages.
+    consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithSubscribe")
+    consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithSubscribe")
+    val consumer = new KafkaConsumer(this.consumerConfig, new 
ByteArrayDeserializer(), new ByteArrayDeserializer())
+    try {
+      val listener0 = new TestConsumerReassignmentListener
+      consumer.subscribe(List(topic, topic2).asJava, listener0)
+      var records: ConsumerRecords[Array[Byte], Array[Byte]] = 
ConsumerRecords.empty()
+      TestUtils.waitUntilTrue(() => {
+        records = consumer.poll(100)
+        !records.records(tp).isEmpty
+      }, "Consumer did not consume any message before timeout.")
+      assertEquals("should be assigned once", 1, listener0.callsToAssigned)
+      // Verify the metric exist.
+      val tags = Collections.singletonMap("client-id", 
"testPerPartitionLagMetricsCleanUpWithSubscribe")
+      val fetchLag0 = consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags))
+      assertNotNull(fetchLag0)
+      val expectedLag = numMessages - records.count
+      assertEquals(s"The lag should be $expectedLag", expectedLag, 
fetchLag0.value, epsilon)
+
+      // Remove topic from subscription
+      consumer.subscribe(List(topic2).asJava, listener0)
+      TestUtils.waitUntilTrue(() => {
+        consumer.poll(100)
+        listener0.callsToAssigned >= 2
+      }, "Expected rebalance did not occur.")
+      // Verify the metric has gone
+      assertNull(consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags)))
+      assertNull(consumer.metrics.get(new MetricName(tp2 + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags)))
+    } finally {
+      consumer.close()
+    }
+  }
+
+  @Test
+  def testPerPartitionLagMetricsCleanUpWithAssign() {
+    val numMessages = 1000
+    // Test assign
+    // send some messages.
+    sendRecords(numMessages, tp)
+    sendRecords(numMessages, tp2)
+    consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithAssign")
+    consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, 
"testPerPartitionLagMetricsCleanUpWithAssign")
+    val consumer = new KafkaConsumer(this.consumerConfig, new 
ByteArrayDeserializer(), new ByteArrayDeserializer())
+    try {
+      consumer.assign(List(tp).asJava)
+      var records: ConsumerRecords[Array[Byte], Array[Byte]] = 
ConsumerRecords.empty()
+      TestUtils.waitUntilTrue(() => {
+        records = consumer.poll(100)
+        !records.records(tp).isEmpty
+      }, "Consumer did not consume any message before timeout.")
+      // Verify the metric exist.
+      val tags = Collections.singletonMap("client-id", 
"testPerPartitionLagMetricsCleanUpWithAssign")
+      val fetchLag = consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags))
+      assertNotNull(fetchLag)
+      val expectedLag = numMessages - records.count
+      assertEquals(s"The lag should be $expectedLag", expectedLag, 
fetchLag.value, epsilon)
+
+      consumer.assign(List(tp2).asJava)
+      TestUtils.waitUntilTrue(() => !consumer.poll(100).isEmpty, "Consumer did 
not consume any message before timeout.")
+      assertNull(consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags)))
+    } finally {
+      consumer.close()
+    }
+  }
+
+  @Test
+  def testPerPartitionLagWithMaxPollRecords() {
+    val numMessages = 1000
+    val maxPollRecords = 10
+    sendRecords(numMessages, tp)
+    consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"testPerPartitionLagWithMaxPollRecords")
+    consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, 
"testPerPartitionLagWithMaxPollRecords")
+    consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 
maxPollRecords.toString)
+    val consumer = new KafkaConsumer(this.consumerConfig, new 
ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumer.assign(List(tp).asJava)
+    try {
+      var records: ConsumerRecords[Array[Byte], Array[Byte]] = 
ConsumerRecords.empty()
+      TestUtils.waitUntilTrue(() => {
+        records = consumer.poll(100)
+        !records.isEmpty
+      }, "Consumer did not consume any message before timeout.")
+      val tags = Collections.singletonMap("client-id", 
"testPerPartitionLagWithMaxPollRecords")
+      val lag = consumer.metrics.get(new MetricName(tp + ".records-lag", 
"consumer-fetch-manager-metrics", "", tags))
+      assertEquals(s"The lag should be ${numMessages - records.count}", 
numMessages - records.count, lag.value, epsilon)
+    } finally {
+      consumer.close()
+    }
+  }
+
   def runMultiConsumerSessionTimeoutTest(closeConsumer: Boolean): Unit = {
     // use consumers defined in this class plus one additional consumer
     // Use topic defined in this class + one additional topic

Reply via email to