Repository: spark
Updated Branches:
  refs/heads/master bc6ea614a -> 79e06faa4


[SPARK-19185][DSTREAMS] Avoid concurrent use of cached consumers in CachedKafkaConsumer

## What changes were proposed in this pull request?

`CachedKafkaConsumer` in the project streaming-kafka-0-10 is designed to
maintain a pool of KafkaConsumers that can be reused. However, it was built
with the assumption that only one thread at a time reads a given Kafka
TopicPartition. This assumption does not always hold, and when two threads
end up sharing the same (non-thread-safe) KafkaConsumer, the result can be a
ConcurrentModificationException.
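To make the failure mode concrete, here is a simplified model of the ownership check a KafkaConsumer performs on entry to its public methods: the first thread to enter becomes the owner, and any other thread that calls in before the owner leaves gets a `ConcurrentModificationException`. `SingleThreadGuard` and `guarded` are hypothetical names; this is a sketch of the behavior, not Kafka's source.

```scala
import java.util.ConcurrentModificationException
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

// Hypothetical, simplified model of KafkaConsumer's single-thread check.
// Not Kafka's source: just enough to show why sharing one consumer fails.
final class SingleThreadGuard {
  private val NoOwner = -1L
  private val owner = new AtomicLong(NoOwner)  // id of the thread currently inside
  private val refCount = new AtomicInteger(0)  // re-entrancy depth of that thread

  def acquire(): Unit = {
    val me = Thread.currentThread().getId()
    // The owning thread may re-enter; any other thread fails unless nobody owns it.
    if (me != owner.get() && !owner.compareAndSet(NoOwner, me))
      throw new ConcurrentModificationException("not safe for multi-threaded access")
    refCount.incrementAndGet()
  }

  def release(): Unit =
    if (refCount.decrementAndGet() == 0) owner.set(NoOwner)

  // Wraps one "consumer call" (for example a poll) in acquire/release.
  def guarded[T](body: => T): T = {
    acquire()
    try body
    finally release()
  }
}
```

Re-entrant calls from the owning thread are allowed; only a second thread is rejected, which is the situation created when two tasks share one cached consumer.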

A better design is for the consumer pool to be smart enough to avoid
concurrent use of a cached consumer: if a request arrives for the same
TopicPartition as a consumer that is currently in use, the pool should
automatically return a fresh consumer instead.
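A minimal sketch of that policy, assuming a generic resource type in place of a Kafka consumer: each key has at most one cached entry with an in-use flag, a busy entry causes a fresh uncached resource to be handed out, and releasing either marks the cached resource idle or closes the extra one. `ExclusivePool` and its methods are hypothetical and not part of Spark.

```scala
import scala.collection.mutable

// Hypothetical pool: at most one cached resource per key, guarded by an in-use flag.
final class ExclusivePool[K, R <: AnyRef](create: K => R, close: R => Unit) {
  private final class Entry(val resource: R) { var inUse = false }
  private val cache = mutable.HashMap.empty[K, Entry]

  /** Returns the resource and whether it is the cached one (true) or a fresh extra (false). */
  def acquire(key: K): (R, Boolean) = synchronized {
    cache.get(key) match {
      case Some(e) if !e.inUse =>
        e.inUse = true
        (e.resource, true)      // idle cached resource: reuse it
      case Some(_) =>
        (create(key), false)    // cached resource is busy: hand out a fresh, uncached one
      case None =>
        val e = new Entry(create(key))
        e.inUse = true
        cache.put(key, e)
        (e.resource, true)      // first request for this key: create and cache
    }
  }

  def release(key: K, resource: R, fromCache: Boolean): Unit = synchronized {
    cache.get(key) match {
      // Only the exact cached instance goes back to the pool.
      case Some(e) if fromCache && (e.resource eq resource) => e.inUse = false
      // Uncached extras (or instances no longer in the cache) are closed.
      case _ => close(resource)
    }
  }
}
```

Because the fresh resource is never inserted into the map, closing it on release cannot disturb the cached entry.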

- There are effectively two kinds of consumer that may be generated:
  - Cached consumer - this should be returned to the pool at task end
  - Non-cached consumer - this should be closed at task end
- A trait called `KafkaDataConsumer` is introduced to hide this difference from
  the users of the consumer, so that client code does not have to reason about
  whether to close the consumer or return it to the pool. They simply call
  `val consumer = KafkaDataConsumer.acquire` and then `consumer.release`.
- If there is a request for a consumer that is already in use, a new
  non-cached consumer is generated.
- If there is a request for a consumer from a task reattempt, the existing
  cached consumer is invalidated and a new consumer is generated. This could
  fix potential issues if the source of the reattempt is a malfunctioning
  consumer.
- In addition, I renamed the `CachedKafkaConsumer` class to `KafkaDataConsumer`
  because it is a misnomer given that what it returns may or may not be cached.
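The split between cached and non-cached consumers behind a single trait can be sketched as a sealed hierarchy in which only `release()` differs. `FakeConsumer`, `DataConsumer`, and `withConsumer` are hypothetical stand-ins, not Spark's `KafkaDataConsumer` API; the point is that the caller's acquire/use/release code is identical in both cases.

```scala
// Hypothetical stand-in for an underlying consumer; it only records whether it was closed.
final class FakeConsumer(val id: Int) {
  @volatile var closed = false
  def close(): Unit = closed = true
}

// Callers only see release(); each subtype decides what releasing means.
sealed trait DataConsumer {
  def underlying: FakeConsumer
  def release(): Unit
}

object DataConsumer {
  // Cached: hand the consumer back to the pool (modeled here by a callback).
  final case class Cached(underlying: FakeConsumer, giveBack: FakeConsumer => Unit)
      extends DataConsumer {
    def release(): Unit = giveBack(underlying)
  }

  // Non-cached: nobody else will reuse it, so close it.
  final case class NonCached(underlying: FakeConsumer) extends DataConsumer {
    def release(): Unit = underlying.close()
  }

  // The calling pattern: acquire, use, and always release in `finally`.
  def withConsumer[T](acquire: => DataConsumer)(use: DataConsumer => T): T = {
    val c = acquire
    try use(c)
    finally c.release()
  }
}
```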

## How was this patch tested?

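A new test suite, `KafkaDataConsumerSuite`, checks that a consumer is reused
for the same groupId and TopicPartition, and adds a stress test verifying that
it is safe to concurrently acquire consumers for the same TopicPartition from
the consumer pool.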
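A stress test of this kind needs a way to detect that the same consumer object was handed to two threads at once. One possible harness, sketched below with hypothetical names, records every checked-out object in an identity-based concurrent set and flags a violation when an `add` fails. As a sanity check, a deliberately broken pool that always returns one shared object (roughly the old `CachedKafkaConsumer.get` behavior) should be flagged, while a pool that always returns new objects should not.

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical test harness; not part of Spark's test utilities.
object PoolStress {
  /** Wraps a reference so that set membership uses identity rather than `equals`. */
  final class IdentityKey(val ref: AnyRef) {
    override def hashCode(): Int = System.identityHashCode(ref)
    override def equals(other: Any): Boolean = other match {
      case k: IdentityKey => k.ref eq ref
      case _              => false
    }
  }

  /** Runs `usages` acquire/use/release cycles on `threads` threads; true if any object was shared. */
  def violatesExclusivity[R <: AnyRef](
      acquire: () => R,
      release: R => Unit,
      threads: Int = 16,
      usages: Int = 500): Boolean = {
    val busy = ConcurrentHashMap.newKeySet[IdentityKey]()
    val violated = new AtomicBoolean(false)
    val executor = Executors.newFixedThreadPool(threads)
    try {
      val futures = (1 to usages).map { _ =>
        executor.submit(new Runnable {
          override def run(): Unit = {
            val r = acquire()
            val key = new IdentityKey(r)
            // add() returns false if another thread already holds this exact object.
            if (!busy.add(key)) violated.set(true)
            try Thread.sleep(1) // simulate a short read so that usages overlap
            finally {
              busy.remove(key)
              release(r)
            }
          }
        })
      }
      futures.foreach(_.get(1, TimeUnit.MINUTES))
    } finally executor.shutdown()
    violated.get()
  }
}
```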

Author: Gabor Somogyi <gabor.g.somo...@gmail.com>

Closes #20997 from gaborgsomogyi/SPARK-19185.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79e06faa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79e06faa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79e06faa

Branch: refs/heads/master
Commit: 79e06faa4ef6596c9e2d4be09c74b935064021bb
Parents: bc6ea61
Author: Gabor Somogyi <gabor.g.somo...@gmail.com>
Authored: Tue May 22 13:43:45 2018 -0700
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Tue May 22 13:43:45 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/kafka010/KafkaDataConsumer.scala  |   2 +-
 .../kafka010/CachedKafkaConsumer.scala          | 226 ------------
 .../streaming/kafka010/KafkaDataConsumer.scala  | 359 +++++++++++++++++++
 .../spark/streaming/kafka010/KafkaRDD.scala     |  20 +-
 .../kafka010/KafkaDataConsumerSuite.scala       | 131 +++++++
 5 files changed, 496 insertions(+), 242 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index 48508d0..941f0ab 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
         // likely running on a beefy machine that can handle a large number of simultaneously
         // active consumers.
 
-        if (entry.getValue.inUse == false && this.size > capacity) {
+        if (!entry.getValue.inUse && this.size > capacity) {
           logWarning(
             s"KafkaConsumer cache hitting max capacity of $capacity, " +
               s"removing consumer for ${entry.getKey}")

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
deleted file mode 100644
index aeb8c1d..0000000
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
+++ /dev/null
@@ -1,226 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka010
-
-import java.{ util => ju }
-
-import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer }
-import org.apache.kafka.common.{ KafkaException, TopicPartition }
-
-import org.apache.spark.internal.Logging
-
-/**
- * Consumer of single topicpartition, intended for cached reuse.
- * Underlying consumer is not threadsafe, so neither is this,
- * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
- */
-private[kafka010]
-class CachedKafkaConsumer[K, V] private(
-  val groupId: String,
-  val topic: String,
-  val partition: Int,
-  val kafkaParams: ju.Map[String, Object]) extends Logging {
-
-  require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG),
-    "groupId used for cache key must match the groupId in kafkaParams")
-
-  val topicPartition = new TopicPartition(topic, partition)
-
-  protected val consumer = {
-    val c = new KafkaConsumer[K, V](kafkaParams)
-    val tps = new ju.ArrayList[TopicPartition]()
-    tps.add(topicPartition)
-    c.assign(tps)
-    c
-  }
-
-  // TODO if the buffer was kept around as a random-access structure,
-  // could possibly optimize re-calculating of an RDD in the same batch
-  protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
-  protected var nextOffset = -2L
-
-  def close(): Unit = consumer.close()
-
-  /**
-   * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
-   * Sequential forward access will use buffers, but random access will be horribly inefficient.
-   */
-  def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
-    logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset")
-    if (offset != nextOffset) {
-      logInfo(s"Initial fetch for $groupId $topic $partition $offset")
-      seek(offset)
-      poll(timeout)
-    }
-
-    if (!buffer.hasNext()) { poll(timeout) }
-    require(buffer.hasNext(),
-      s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
-    var record = buffer.next()
-
-    if (record.offset != offset) {
-      logInfo(s"Buffer miss for $groupId $topic $partition $offset")
-      seek(offset)
-      poll(timeout)
-      require(buffer.hasNext(),
-        s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
-      record = buffer.next()
-      require(record.offset == offset,
-        s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " +
-          s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
-          "spark.streaming.kafka.allowNonConsecutiveOffsets"
-      )
-    }
-
-    nextOffset = offset + 1
-    record
-  }
-
-  /**
-   * Start a batch on a compacted topic
-   */
-  def compactedStart(offset: Long, timeout: Long): Unit = {
-    logDebug(s"compacted start $groupId $topic $partition starting $offset")
-    // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
-    if (offset != nextOffset) {
-      logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset")
-      seek(offset)
-      poll(timeout)
-    }
-  }
-
-  /**
-   * Get the next record in the batch from a compacted topic.
-   * Assumes compactedStart has been called first, and ignores gaps.
-   */
-  def compactedNext(timeout: Long): ConsumerRecord[K, V] = {
-    if (!buffer.hasNext()) {
-      poll(timeout)
-    }
-    require(buffer.hasNext(),
-      s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout")
-    val record = buffer.next()
-    nextOffset = record.offset + 1
-    record
-  }
-
-  /**
-   * Rewind to previous record in the batch from a compacted topic.
-   * @throws NoSuchElementException if no previous element
-   */
-  def compactedPrevious(): ConsumerRecord[K, V] = {
-    buffer.previous()
-  }
-
-  private def seek(offset: Long): Unit = {
-    logDebug(s"Seeking to $topicPartition $offset")
-    consumer.seek(topicPartition, offset)
-  }
-
-  private def poll(timeout: Long): Unit = {
-    val p = consumer.poll(timeout)
-    val r = p.records(topicPartition)
-    logDebug(s"Polled ${p.partitions()}  ${r.size}")
-    buffer = r.listIterator
-  }
-
-}
-
-private[kafka010]
-object CachedKafkaConsumer extends Logging {
-
-  private case class CacheKey(groupId: String, topic: String, partition: Int)
-
-  // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
-  private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null
-
-  /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */
-  def init(
-      initialCapacity: Int,
-      maxCapacity: Int,
-      loadFactor: Float): Unit = CachedKafkaConsumer.synchronized {
-    if (null == cache) {
-      logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
-      cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]](
-        initialCapacity, loadFactor, true) {
-        override def removeEldestEntry(
-          entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = {
-          if (this.size > maxCapacity) {
-            try {
-              entry.getValue.consumer.close()
-            } catch {
-              case x: KafkaException =>
-                logError("Error closing oldest Kafka consumer", x)
-            }
-            true
-          } else {
-            false
-          }
-        }
-      }
-    }
-  }
-
-  /**
-   * Get a cached consumer for groupId, assigned to topic and partition.
-   * If matching consumer doesn't already exist, will be created using kafkaParams.
-   */
-  def get[K, V](
-      groupId: String,
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
-    CachedKafkaConsumer.synchronized {
-      val k = CacheKey(groupId, topic, partition)
-      val v = cache.get(k)
-      if (null == v) {
-        logInfo(s"Cache miss for $k")
-        logDebug(cache.keySet.toString)
-        val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
-        cache.put(k, c)
-        c
-      } else {
-        // any given topicpartition should have a consistent key and value type
-        v.asInstanceOf[CachedKafkaConsumer[K, V]]
-      }
-    }
-
-  /**
-   * Get a fresh new instance, unassociated with the global cache.
-   * Caller is responsible for closing
-   */
-  def getUncached[K, V](
-      groupId: String,
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
-    new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
-
-  /** remove consumer for given groupId, topic, and partition, if it exists */
-  def remove(groupId: String, topic: String, partition: Int): Unit = {
-    val k = CacheKey(groupId, topic, partition)
-    logInfo(s"Removing $k from cache")
-    val v = CachedKafkaConsumer.synchronized {
-      cache.remove(k)
-    }
-    if (null != v) {
-      v.close()
-      logInfo(s"Removed $k from cache")
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
new file mode 100644
index 0000000..68c5fe9
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer}
+import org.apache.kafka.common.{KafkaException, TopicPartition}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+
+private[kafka010] sealed trait KafkaDataConsumer[K, V] {
+  /**
+   * Get the record for the given offset if available.
+   *
+   * @param offset         the offset to fetch.
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   */
+  def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+    internalConsumer.get(offset, pollTimeoutMs)
+  }
+
+  /**
+   * Start a batch on a compacted topic
+   *
+   * @param offset         the offset to fetch.
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   */
+  def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+    internalConsumer.compactedStart(offset, pollTimeoutMs)
+  }
+
+  /**
+   * Get the next record in the batch from a compacted topic.
+   * Assumes compactedStart has been called first, and ignores gaps.
+   *
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   */
+  def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+    internalConsumer.compactedNext(pollTimeoutMs)
+  }
+
+  /**
+   * Rewind to previous record in the batch from a compacted topic.
+   *
+   * @throws NoSuchElementException if no previous element
+   */
+  def compactedPrevious(): ConsumerRecord[K, V] = {
+    internalConsumer.compactedPrevious()
+  }
+
+  /**
+   * Release this consumer from being further used. Depending on its implementation,
+   * this consumer will be either finalized, or reset for reuse later.
+   */
+  def release(): Unit
+
+  /** Reference to the internal implementation that this wrapper delegates to */
+  def internalConsumer: InternalKafkaConsumer[K, V]
+}
+
+
+/**
+ * A wrapper around Kafka's KafkaConsumer.
+ * This is not for direct use outside this file.
+ */
+private[kafka010] class InternalKafkaConsumer[K, V](
+    val topicPartition: TopicPartition,
+    val kafkaParams: ju.Map[String, Object]) extends Logging {
+
+  private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG)
+    .asInstanceOf[String]
+
+  private val consumer = createConsumer
+
+  /** indicates whether this consumer is in use or not */
+  var inUse = true
+
+  /** indicate whether this consumer is going to be stopped in the next release */
+  var markedForClose = false
+
+  // TODO if the buffer was kept around as a random-access structure,
+  // could possibly optimize re-calculating of an RDD in the same batch
+  @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
+  @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET
+
+  override def toString: String = {
+    "InternalKafkaConsumer(" +
+      s"hash=${Integer.toHexString(hashCode)}, " +
+      s"groupId=$groupId, " +
+      s"topicPartition=$topicPartition)"
+  }
+
+  /** Create a KafkaConsumer to fetch records for `topicPartition` */
+  private def createConsumer: KafkaConsumer[K, V] = {
+    val c = new KafkaConsumer[K, V](kafkaParams)
+    val topics = ju.Arrays.asList(topicPartition)
+    c.assign(topics)
+    c
+  }
+
+  def close(): Unit = consumer.close()
+
+  /**
+   * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
+   * Sequential forward access will use buffers, but random access will be horribly inefficient.
+   */
+  def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
+    logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset")
+    if (offset != nextOffset) {
+      logInfo(s"Initial fetch for $groupId $topicPartition $offset")
+      seek(offset)
+      poll(timeout)
+    }
+
+    if (!buffer.hasNext()) {
+      poll(timeout)
+    }
+    require(buffer.hasNext(),
+      s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+    var record = buffer.next()
+
+    if (record.offset != offset) {
+      logInfo(s"Buffer miss for $groupId $topicPartition $offset")
+      seek(offset)
+      poll(timeout)
+      require(buffer.hasNext(),
+        s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+      record = buffer.next()
+      require(record.offset == offset,
+        s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " +
+          s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
+          "spark.streaming.kafka.allowNonConsecutiveOffsets"
+      )
+    }
+
+    nextOffset = offset + 1
+    record
+  }
+
+  /**
+   * Start a batch on a compacted topic
+   */
+  def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+    logDebug(s"compacted start $groupId $topicPartition starting $offset")
+    // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
+    if (offset != nextOffset) {
+      logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset")
+      seek(offset)
+      poll(pollTimeoutMs)
+    }
+  }
+
+  /**
+   * Get the next record in the batch from a compacted topic.
+   * Assumes compactedStart has been called first, and ignores gaps.
+   */
+  def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+    if (!buffer.hasNext()) {
+      poll(pollTimeoutMs)
+    }
+    require(buffer.hasNext(),
+      s"Failed to get records for compacted $groupId $topicPartition " +
+        s"after polling for $pollTimeoutMs")
+    val record = buffer.next()
+    nextOffset = record.offset + 1
+    record
+  }
+
+  /**
+   * Rewind to previous record in the batch from a compacted topic.
+   * @throws NoSuchElementException if no previous element
+   */
+  def compactedPrevious(): ConsumerRecord[K, V] = {
+    buffer.previous()
+  }
+
+  private def seek(offset: Long): Unit = {
+    logDebug(s"Seeking to $topicPartition $offset")
+    consumer.seek(topicPartition, offset)
+  }
+
+  private def poll(timeout: Long): Unit = {
+    val p = consumer.poll(timeout)
+    val r = p.records(topicPartition)
+    logDebug(s"Polled ${p.partitions()}  ${r.size}")
+    buffer = r.listIterator
+  }
+
+}
+
+private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition)
+
+private[kafka010] object KafkaDataConsumer extends Logging {
+
+  private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+    extends KafkaDataConsumer[K, V] {
+    assert(internalConsumer.inUse)
+    override def release(): Unit = KafkaDataConsumer.release(internalConsumer)
+  }
+
+  private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+    extends KafkaDataConsumer[K, V] {
+    override def release(): Unit = internalConsumer.close()
+  }
+
+  // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
+  private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null
+
+  /**
+   * Must be called before acquire, once per JVM, to configure the cache.
+   * Further calls are ignored.
+   */
+  def init(
+      initialCapacity: Int,
+      maxCapacity: Int,
+      loadFactor: Float): Unit = synchronized {
+    if (null == cache) {
+      logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
+      cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]](
+        initialCapacity, loadFactor, true) {
+        override def removeEldestEntry(
+            entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = {
+
+          // Try to remove the least-recently-used entry if it's currently not in use.
+          //
+          // If you cannot remove it, then the cache will keep growing. In the worst case,
+          // the cache will grow to the max number of concurrent tasks that can run in the executor,
+          // (that is, number of task slots) after which it will never reduce. This is unlikely to
+          // be a serious problem because an executor with more than 64 (default) task slots is
+          // likely running on a beefy machine that can handle a large number of simultaneously
+          // active consumers.
+
+          if (entry.getValue.inUse == false && this.size > maxCapacity) {
+            logWarning(
+                s"KafkaConsumer cache hitting max capacity of $maxCapacity, " +
+                s"removing consumer for ${entry.getKey}")
+            try {
+              entry.getValue.close()
+            } catch {
+              case x: KafkaException =>
+                logError("Error closing oldest Kafka consumer", x)
+            }
+            true
+          } else {
+            false
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Get a cached consumer for groupId, assigned to topic and partition.
+   * If matching consumer doesn't already exist, will be created using kafkaParams.
+   * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]].
+   *
+   * Note: This method guarantees that the consumer returned is not currently in use by anyone
+   * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by
+   * caching them and tracking when they are in use.
+   */
+  def acquire[K, V](
+      topicPartition: TopicPartition,
+      kafkaParams: ju.Map[String, Object],
+      context: TaskContext,
+      useCache: Boolean): KafkaDataConsumer[K, V] = synchronized {
+    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+    val key = new CacheKey(groupId, topicPartition)
+    val existingInternalConsumer = cache.get(key)
+
+    lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams)
+
+    if (context != null && context.attemptNumber >= 1) {
+      // If this is a reattempt at running the task, then invalidate cached consumers if any and
+      // start with a new one. If prior attempt failures were cache related then this way old
+      // problematic consumers can be removed.
+      logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer")
+      if (existingInternalConsumer != null) {
+        // Consumer exists in cache. If it's in use, mark it for closing later, or close it now.
+        if (existingInternalConsumer.inUse) {
+          existingInternalConsumer.markedForClose = true
+        } else {
+          existingInternalConsumer.close()
+          // Remove the consumer from cache only if it's closed.
+          // Marked for close consumers will be removed in release function.
+          cache.remove(key)
+        }
+      }
+
+      logDebug("Reattempt detected, new non-cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+    } else if (!useCache) {
+      // If consumer reuse turned off, then do not use it, return a new consumer
+      logDebug("Cache usage turned off, new non-cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+    } else if (existingInternalConsumer == null) {
+      // If consumer is not already cached, then put a new one in the cache and return it
+      logDebug("No cached consumer, new cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      cache.put(key, newInternalConsumer)
+      CachedKafkaDataConsumer(newInternalConsumer)
+    } else if (existingInternalConsumer.inUse) {
+      // If consumer is already cached but is currently in use, then return a new consumer
+      logDebug("Used cached consumer found, new non-cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+    } else {
+      // If consumer is already cached and is currently not in use, then return that consumer
+      logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer")
+      existingInternalConsumer.inUse = true
+      // Any given TopicPartition should have a consistent key and value type
+      CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]])
+    }
+  }
+
+  private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized {
+    // Clear the consumer from the cache if this is indeed the consumer present in the cache
+    val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition)
+    val cachedInternalConsumer = cache.get(key)
+    if (internalConsumer.eq(cachedInternalConsumer)) {
+      // The released consumer is the same object as the cached one.
+      if (internalConsumer.markedForClose) {
+        internalConsumer.close()
+        cache.remove(key)
+      } else {
+        internalConsumer.inUse = false
+      }
+    } else {
+      // The released consumer is either not the same one as in the cache, or not in the cache
+      // at all. This may happen if the cache was invalidated while this consumer was being used.
+      // Just close this consumer.
+      internalConsumer.close()
+      logInfo(s"Released a supposedly cached consumer that was not found in the cache " +
+        s"$internalConsumer")
+    }
+  }
+}
+
+private[kafka010] object InternalKafkaConsumer {
+  private val UNKNOWN_OFFSET = -2L
+}
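The eviction rule in `init` relies on two properties of `java.util.LinkedHashMap`: the three-argument constructor with `accessOrder = true` keeps entries in least-recently-accessed order, and `removeEldestEntry` is called after each insertion of a new key with the current eldest entry. The sketch below, with a hypothetical `Slot` type standing in for `InternalKafkaConsumer`, shows that a busy eldest entry blocks eviction, so the map can exceed its capacity as the source comment describes.

```scala
import java.{util => ju}

// Hypothetical stand-in for a cached consumer: an in-use flag plus a closed flag.
final class Slot(val name: String) {
  var inUse = false
  var closed = false
}

object EvictionSketch {
  def boundedCache(maxCapacity: Int): ju.LinkedHashMap[String, Slot] =
    // accessOrder = true: the "eldest" entry is the least-recently-accessed one.
    new ju.LinkedHashMap[String, Slot](16, 0.75f, true) {
      override def removeEldestEntry(eldest: ju.Map.Entry[String, Slot]): Boolean =
        if (!eldest.getValue.inUse && size() > maxCapacity) {
          eldest.getValue.closed = true // stand-in for closing the evicted consumer
          true
        } else {
          false // eldest is busy or we are within capacity: keep every entry
        }
    }
}
```

Only the single eldest entry is ever examined per insertion, so one long-running task holding the eldest consumer is enough to let the cache grow past `maxCapacity`.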

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index 07239ed..81abc98 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010
 
 import java.{ util => ju }
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord }
 import org.apache.kafka.common.TopicPartition
 
@@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V](
   cacheLoadFactor: Float
 ) extends Iterator[ConsumerRecord[K, V]] {
 
-  val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-
   context.addTaskCompletionListener(_ => closeIfNeeded())
 
-  val consumer = if (useConsumerCache) {
-    CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
-    if (context.attemptNumber >= 1) {
-      // just in case the prior attempt failures were cache related
-      CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
-    }
-    CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams)
-  } else {
-    CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams)
+  val consumer = {
+    KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
+    KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache)
   }
 
   var requestOffset = part.fromOffset
 
   def closeIfNeeded(): Unit = {
-    if (!useConsumerCache && consumer != null) {
-      consumer.close()
+    if (consumer != null) {
+      consumer.release()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
new file mode 100644
index 0000000..d934c64
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer.ConsumerConfig._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+
+class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll {
+  private var testUtils: KafkaTestUtils = _
+  private val topic = "topic" + Random.nextInt()
+  private val topicPartition = new TopicPartition(topic, 0)
+  private val groupId = "groupId"
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    testUtils = new KafkaTestUtils
+    testUtils.setup()
+    KafkaDataConsumer.init(16, 64, 0.75f)
+  }
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+
+  private def getKafkaParams() = Map[String, Object](
+    GROUP_ID_CONFIG -> groupId,
+    BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+    KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+    VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+    AUTO_OFFSET_RESET_CONFIG -> "earliest",
+    ENABLE_AUTO_COMMIT_CONFIG -> "false"
+  ).asJava
+
+  test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") {
+    KafkaDataConsumer.cache.clear()
+
+    val kafkaParams = getKafkaParams()
+
+    val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+      topicPartition, kafkaParams, null, true)
+    consumer1.release()
+
+    val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+      topicPartition, kafkaParams, null, true)
+    consumer2.release()
+
+    assert(KafkaDataConsumer.cache.size() == 1)
+    val key = new CacheKey(groupId, topicPartition)
+    val existingInternalConsumer = KafkaDataConsumer.cache.get(key)
+    assert(existingInternalConsumer.eq(consumer1.internalConsumer))
+    assert(existingInternalConsumer.eq(consumer2.internalConsumer))
+  }
+
+  test("concurrent use of KafkaDataConsumer") {
+    val data = (1 to 1000).map(_.toString)
+    testUtils.createTopic(topic)
+    testUtils.sendMessages(topic, data.toArray)
+
+    val kafkaParams = getKafkaParams()
+
+    val numThreads = 100
+    val numConsumerUsages = 500
+
+    @volatile var error: Throwable = null
+
+    def consume(i: Int): Unit = {
+      val useCache = Random.nextBoolean
+      val taskContext = if (Random.nextBoolean) {
+        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+      } else {
+        null
+      }
+      val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+        topicPartition, kafkaParams, taskContext, useCache)
+      try {
+        val rcvd = (0 until data.length).map { offset =>
+          val bytes = consumer.get(offset, 10000).value()
+          new String(bytes)
+        }
+        assert(rcvd == data)
+      } catch {
+        case e: Throwable =>
+          error = e
+          throw e
+      } finally {
+        consumer.release()
+      }
+    }
+
+    val threadPool = Executors.newFixedThreadPool(numThreads)
+    try {
+      val futures = (1 to numConsumerUsages).map { i =>
+        threadPool.submit(new Runnable {
+          override def run(): Unit = { consume(i) }
+        })
+      }
+      futures.foreach(_.get(1, TimeUnit.MINUTES))
+      assert(error == null)
+    } finally {
+      threadPool.shutdown()
+    }
+  }
+}

