Repository: spark
Updated Branches:
  refs/heads/master 1c7db00c7 -> 96a4d1d08


[SPARK-19968][SS] Use a cached instance of `KafkaProducer` instead of creating one every batch.

## What changes were proposed in this pull request?

In summary, the cost of recreating a KafkaProducer for every batch written is high: each new instance starts a number of threads and opens connections, only to tear them down again when the batch finishes. The Kafka docs promise that a KafkaProducer instance is thread safe, and reusing a single instance across multiple writing threads is encouraged.
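
For illustration, here is a minimal sketch of the caching approach (the actual implementation is `CachedKafkaProducer.scala` in the diff below): producers are keyed by their sorted configuration and closed when they are evicted after a period of inactivity. The object name `ProducerCacheSketch` and the hard-coded 10-minute timeout are placeholders, not part of this patch.

```scala
// Illustration only: a simplified version of the producer-caching idea.
import java.{util => ju}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

import com.google.common.cache._
import org.apache.kafka.clients.producer.KafkaProducer

object ProducerCacheSketch {
  type Producer = KafkaProducer[Array[Byte], Array[Byte]]
  // Key producers by their (sorted) configuration so identical configs share one instance.
  type Key = Seq[(String, Object)]

  private val loader = new CacheLoader[Key, Producer] {
    override def load(key: Key): Producer = new Producer(key.toMap.asJava)
  }

  private val onEvict = new RemovalListener[Key, Producer] {
    // Close the producer once it falls out of the cache.
    override def onRemoval(n: RemovalNotification[Key, Producer]): Unit = n.getValue.close()
  }

  private val cache: LoadingCache[Key, Producer] = CacheBuilder.newBuilder()
    .expireAfterAccess(10, TimeUnit.MINUTES) // placeholder for a configurable timeout
    .removalListener(onEvict)
    .build[Key, Producer](loader)

  // Callers obtain a shared, thread-safe producer per batch; they only flush it,
  // and the cache closes the producer on eviction.
  def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer =
    cache.get(kafkaParams.asScala.toSeq.sortBy(_._1))
}
```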

Furthermore, with this patch I observed roughly a 10x improvement in latency.

### Times taken by addBatch (in ms) without this patch
![with-out_patch](https://cloud.githubusercontent.com/assets/992952/23994612/a9de4a42-0a6b-11e7-9d5b-7ae18775bee4.png)
### Times taken by addBatch (in ms) after applying this patch
![with_patch](https://cloud.githubusercontent.com/assets/992952/23994616/ad8c11ec-0a6b-11e7-8634-2266ebb5033f.png)

## How was this patch tested?
Ran distributed benchmarks comparing runs with and without this patch.
Added relevant unit tests.

Author: Prashant Sharma <prash...@in.ibm.com>

Closes #17308 from ScrapCodes/cached-kafka-producer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/96a4d1d0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/96a4d1d0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/96a4d1d0

Branch: refs/heads/master
Commit: 96a4d1d0827fc3fba83f174510b061684f0d00f7
Parents: 1c7db00
Author: Prashant Sharma <prash...@in.ibm.com>
Authored: Mon May 29 18:12:01 2017 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Mon May 29 18:12:01 2017 -0700

----------------------------------------------------------------------
 .../sql/kafka010/CachedKafkaProducer.scala      | 112 +++++++++++++++++++
 .../apache/spark/sql/kafka010/KafkaSource.scala |  14 +--
 .../spark/sql/kafka010/KafkaWriteTask.scala     |  17 ++-
 .../apache/spark/sql/kafka010/KafkaWriter.scala |   3 +-
 .../sql/kafka010/CachedKafkaProducerSuite.scala |  78 +++++++++++++
 5 files changed, 206 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/96a4d1d0/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
new file mode 100644
index 0000000..571140b
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}
+
+import com.google.common.cache._
+import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
+import org.apache.kafka.clients.producer.KafkaProducer
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+
+private[kafka010] object CachedKafkaProducer extends Logging {
+
+  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
+
+  private lazy val cacheExpireTimeout: Long =
+    SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m")
+
+  private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
+    override def load(config: Seq[(String, Object)]): Producer = {
+      val configMap = config.map(x => x._1 -> x._2).toMap.asJava
+      createKafkaProducer(configMap)
+    }
+  }
+
+  private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
+    override def onRemoval(
+        notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
+      val paramsSeq: Seq[(String, Object)] = notification.getKey
+      val producer: Producer = notification.getValue
+      logDebug(
+        s"Evicting kafka producer $producer params: $paramsSeq, due to 
${notification.getCause}")
+      close(paramsSeq, producer)
+    }
+  }
+
+  private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
+    CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
+      .removalListener(removalListener)
+      .build[Seq[(String, Object)], Producer](cacheLoader)
+
+  private def createKafkaProducer(producerConfiguration: ju.Map[String, Object]): Producer = {
+    val kafkaProducer: Producer = new Producer(producerConfiguration)
+    logDebug(s"Created a new instance of KafkaProducer for 
$producerConfiguration.")
+    kafkaProducer
+  }
+
+  /**
+   * Get a cached KafkaProducer for a given configuration. If matching KafkaProducer doesn't
+   * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep
+   * one instance per specified kafkaParams.
+   */
+  private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
+    val paramsSeq: Seq[(String, Object)] = paramsToSeq(kafkaParams)
+    try {
+      guavaCache.get(paramsSeq)
+    } catch {
+      case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
+        if e.getCause != null =>
+        throw e.getCause
+    }
+  }
+
+  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
+    val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
+    paramsSeq
+  }
+
+  /** For explicitly closing kafka producer */
+  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
+    val paramsSeq = paramsToSeq(kafkaParams)
+    guavaCache.invalidate(paramsSeq)
+  }
+
+  /** Auto close on cache evict */
+  private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
+    try {
+      logInfo(s"Closing the KafkaProducer with params: 
${paramsSeq.mkString("\n")}.")
+      producer.close()
+    } catch {
+      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
+    }
+  }
+
+  private def clear(): Unit = {
+    logInfo("Cleaning up guava cache.")
+    guavaCache.invalidateAll()
+  }
+
+  // Intended for testing purpose only.
+  private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/96a4d1d0/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 1fb0a33..7ac1837 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -70,13 +70,13 @@ import org.apache.spark.unsafe.types.UTF8String
  * and not use wrong broker addresses.
  */
 private[kafka010] class KafkaSource(
-                                     sqlContext: SQLContext,
-                                     kafkaReader: KafkaOffsetReader,
-                                     executorKafkaParams: ju.Map[String, Object],
-                                     sourceOptions: Map[String, String],
-                                     metadataPath: String,
-                                     startingOffsets: KafkaOffsetRangeLimit,
-                                     failOnDataLoss: Boolean)
+    sqlContext: SQLContext,
+    kafkaReader: KafkaOffsetReader,
+    executorKafkaParams: ju.Map[String, Object],
+    sourceOptions: Map[String, String],
+    metadataPath: String,
+    startingOffsets: KafkaOffsetRangeLimit,
+    failOnDataLoss: Boolean)
   extends Source with Logging {
 
   private val sc = sqlContext.sparkContext

http://git-wip-us.apache.org/repos/asf/spark/blob/96a4d1d0/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index 6e160cb..6fd333e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
 
-import org.apache.kafka.clients.producer.{KafkaProducer, _}
-import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -44,7 +43,7 @@ private[kafka010] class KafkaWriteTask(
    * Writes key value data out to topics.
    */
   def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
+    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
       val projectedRow = projection(currentRow)
@@ -68,10 +67,10 @@ private[kafka010] class KafkaWriteTask(
   }
 
   def close(): Unit = {
+    checkForErrors()
     if (producer != null) {
-      checkForErrors
-      producer.close()
-      checkForErrors
+      producer.flush()
+      checkForErrors()
       producer = null
     }
   }
@@ -88,7 +87,7 @@ private[kafka010] class KafkaWriteTask(
       case t =>
         throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} 
" +
           s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} 
" +
-          s"must be a ${StringType}")
+          "must be a StringType")
     }
     val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
       .getOrElse(Literal(null, BinaryType))
@@ -100,7 +99,7 @@ private[kafka010] class KafkaWriteTask(
     }
     val valueExpression = inputSchema
       .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
-      throw new IllegalStateException(s"Required attribute " +
+      throw new IllegalStateException("Required attribute " +
         s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
     )
     valueExpression.dataType match {
@@ -114,7 +113,7 @@ private[kafka010] class KafkaWriteTask(
         Cast(valueExpression, BinaryType)), inputSchema)
   }
 
-  private def checkForErrors: Unit = {
+  private def checkForErrors(): Unit = {
     if (failedWrite != null) {
       throw failedWrite
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/96a4d1d0/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index 61936e3..0ed9d4e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -21,7 +21,6 @@ import java.{util => ju}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.types.{BinaryType, StringType}
@@ -49,7 +48,7 @@ private[kafka010] object KafkaWriter extends Logging {
       topic: Option[String] = None): Unit = {
     val schema = queryExecution.analyzed.output
     schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
-      if (topic == None) {
+      if (topic.isEmpty) {
         throw new AnalysisException(s"topic option required when no " +
           s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
           s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a 
topic.")

http://git-wip-us.apache.org/repos/asf/spark/blob/96a4d1d0/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
new file mode 100644
index 0000000..789bffa
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+import java.util.concurrent.ConcurrentMap
+
+import org.apache.kafka.clients.producer.KafkaProducer
+import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {
+
+  type KP = KafkaProducer[Array[Byte], Array[Byte]]
+
+  protected override def beforeEach(): Unit = {
+    super.beforeEach()
+    val clear = PrivateMethod[Unit]('clear)
+    CachedKafkaProducer.invokePrivate(clear())
+  }
+
+  test("Should return the cached instance on calling getOrCreate with same 
params.") {
+    val kafkaParams = new ju.HashMap[String, Object]()
+    kafkaParams.put("acks", "0")
+    // Here only host should be resolvable, it does not need a running instance of kafka server.
+    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
+    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
+    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
+    val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
+    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams)
+    assert(producer == producer2)
+
+    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
+    val map = CachedKafkaProducer.invokePrivate(cacheMap())
+    assert(map.size == 1)
+  }
+
+  test("Should close the correct kafka producer for the given kafkaPrams.") {
+    val kafkaParams = new ju.HashMap[String, Object]()
+    kafkaParams.put("acks", "0")
+    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
+    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
+    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
+    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
+    kafkaParams.put("acks", "1")
+    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
+    // With updated conf, a new producer instance should be created.
+    assert(producer != producer2)
+
+    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
+    val map = CachedKafkaProducer.invokePrivate(cacheMap())
+    assert(map.size == 2)
+
+    CachedKafkaProducer.close(kafkaParams)
+    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
+    assert(map2.size == 1)
+    import scala.collection.JavaConverters._
+    val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
+    assert(_producer == producer)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
