Repository: spark
Updated Branches:
  refs/heads/master 21fdfd7d6 -> 0a078303d


[SPARK-9556] [SPARK-9619] [SPARK-9624] [STREAMING] Make BlockGenerator more robust and make all BlockGenerators subscribe to rate limit updates

In some receivers, instead of using the default `BlockGenerator` in
`ReceiverSupervisorImpl`, custom generators with their own listeners are used
for reliability (see
[`ReliableKafkaReceiver`](https://github.com/apache/spark/blob/master/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala#L99)
and the [updated `KinesisReceiver`](https://github.com/apache/spark/pull/7825/files)).
These custom generators do not receive rate updates. This PR modifies the code
to allow custom `BlockGenerator`s to be created through the
`ReceiverSupervisorImpl`, so that they can be tracked and rate updates can be
applied to them.
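The mechanism can be illustrated with a small, self-contained sketch. It is not Spark code: `SketchListener`, `SketchGenerator` and `SketchSupervisor` are hypothetical stand-ins for `BlockGeneratorListener`, `BlockGenerator` and `ReceiverSupervisorImpl`, reduced to the one point that matters here, namely that the supervisor is the factory for every generator and therefore knows which ones to update.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for BlockGeneratorListener: only the callback the sketch needs.
trait SketchListener {
  def onPushBlock(records: Seq[Any]): Unit
}

// Hypothetical stand-in for BlockGenerator: it only tracks its current rate limit.
class SketchGenerator(val listener: SketchListener) {
  @volatile private var rate: Long = Long.MaxValue // Long.MaxValue means "no limit"
  def updateRate(newRate: Long): Unit = { rate = newRate }
  def currentLimit: Long = rate
}

// Hypothetical stand-in for ReceiverSupervisorImpl.
class SketchSupervisor {
  private val registered = ArrayBuffer.empty[SketchGenerator]

  // Every generator, default or custom, is created here so that it is tracked.
  def createBlockGenerator(listener: SketchListener): SketchGenerator = synchronized {
    val generator = new SketchGenerator(listener)
    registered += generator
    generator
  }

  // A rate update from the driver is fanned out to all tracked generators.
  def onRateUpdate(eps: Long): Unit = synchronized {
    registered.foreach(_.updateRate(eps))
  }
}
```

Under this scheme a receiver asks the supervisor for its generator in `onStart()`; a generator built directly with `new` would never be registered and would miss rate updates, which is the problem described above.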

In the process, I simplified and de-flaked some of the rate-controller-related
tests. In particular:
- Renamed `Receiver.executor` to `Receiver.supervisor` (to match
  `ReceiverSupervisor`).
- Made `RateControllerSuite` faster (by increasing the batch interval) and less
  flaky.
- Changed a few internal APIs to return the current rate of block generators as
  `Long` instead of `Option[Long]` (the return types were inconsistent in places).
- Updated the existing `ReceiverTrackerSuite` to test that custom block
  generators get rate updates as well.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #7913 from tdas/SPARK-9556 and squashes the following commits:

41d4461 [Tathagata Das] fix scala style
eb9fd59 [Tathagata Das] Updated kinesis receiver
d24994d [Tathagata Das] Updated BlockGeneratorSuite to use manual clock in BlockGenerator
d70608b [Tathagata Das] Updated BlockGenerator with states and proper synchronization
f6bd47e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9556
31da173 [Tathagata Das] Fix bug
12116df [Tathagata Das] Add BlockGeneratorSuite
74bd069 [Tathagata Das] Fix style
989bb5c [Tathagata Das] Made BlockGenerator fail is used after stop, and added better unit tests for it
3ff618c [Tathagata Das] Fix test
b40eff8 [Tathagata Das] slight refactoring
f0df0f1 [Tathagata Das] Scala style fixes
51759cb [Tathagata Das] Refactored rate controller tests and added the ability to update rate of any custom block generator


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a078303
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a078303
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a078303

Branch: refs/heads/master
Commit: 0a078303d08ad2bb92b9a8a6969563d75b512290
Parents: 21fdfd7
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Thu Aug 6 14:35:30 2015 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Thu Aug 6 14:35:30 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/util/ManualClock.scala     |   2 +-
 .../streaming/kafka/ReliableKafkaReceiver.scala |   2 +-
 .../streaming/kinesis/KinesisReceiver.scala     |   2 +-
 .../streaming/receiver/ActorReceiver.scala      |   8 +-
 .../streaming/receiver/BlockGenerator.scala     | 131 +++++++---
 .../spark/streaming/receiver/RateLimiter.scala  |   3 +-
 .../spark/streaming/receiver/Receiver.scala     |  52 ++--
 .../streaming/receiver/ReceiverSupervisor.scala |  27 +-
 .../receiver/ReceiverSupervisorImpl.scala       |  33 ++-
 .../spark/streaming/CheckpointSuite.scala       |  16 +-
 .../apache/spark/streaming/ReceiverSuite.scala  |  31 +--
 .../receiver/BlockGeneratorSuite.scala          | 253 +++++++++++++++++++
 .../scheduler/RateControllerSuite.scala         |  64 ++---
 .../scheduler/ReceiverTrackerSuite.scala        | 129 +++++-----
 14 files changed, 534 insertions(+), 219 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/core/src/main/scala/org/apache/spark/util/ManualClock.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala
index 1718554..e7a65d7 100644
--- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala
+++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala
@@ -58,7 +58,7 @@ private[spark] class ManualClock(private var time: Long) extends Clock {
    */
   def waitTillTime(targetTime: Long): Long = synchronized {
     while (time < targetTime) {
-      wait(100)
+      wait(10)
     }
     getTimeMillis()
   }
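`waitTillTime` is a guarded wait: it sleeps on the clock's monitor until another thread moves the time forward. The sketch below is a hypothetical, simplified clock that assumes advancing the time calls `notifyAll()`; under that assumption the shorter `wait(10)` timeout only bounds how long a waiter can stay parked before it re-checks the condition.

```scala
// Hypothetical minimal manual clock built around the waitTillTime loop from the diff.
class SketchManualClock(private var time: Long) {
  def getTimeMillis(): Long = synchronized { time }

  // Advancing the clock wakes up every thread blocked in waitTillTime.
  def advance(timeToAdd: Long): Unit = synchronized {
    time += timeToAdd
    notifyAll()
  }

  // Blocks until the clock reaches targetTime; wait(10) re-checks the
  // condition at least every 10 ms even if no notification arrives.
  def waitTillTime(targetTime: Long): Long = synchronized {
    while (time < targetTime) {
      wait(10)
    }
    getTimeMillis()
  }
}
```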

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
----------------------------------------------------------------------
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
index 75f0dfc..764d170 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
@@ -96,7 +96,7 @@ class ReliableKafkaReceiver[
     blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]()
 
     // Initialize the block generator for storing Kafka message.
-    blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf)
+    blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler)
 
     if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") {
       logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " +

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index a4baeec..22324e8 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -136,7 +136,7 @@ private[kinesis] class KinesisReceiver(
    * The KCL creates and manages the receiving/processing thread pool through Worker.run().
    */
   override def onStart() {
-    blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, SparkEnv.get.conf)
+    blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler)
 
     workerId = Utils.localHostName() + ":" + UUID.randomUUID()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
index cd30978..7ec7401 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
@@ -144,7 +144,7 @@ private[streaming] class ActorReceiver[T: ClassTag](
     receiverSupervisorStrategy: SupervisorStrategy
   ) extends Receiver[T](storageLevel) with Logging {
 
-  protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor),
+  protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor),
     "Supervisor" + streamId)
 
   class Supervisor extends Actor {
@@ -191,11 +191,11 @@ private[streaming] class ActorReceiver[T: ClassTag](
   }
 
   def onStart(): Unit = {
-    supervisor
-    logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
+    actorSupervisor
+    logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path)
   }
 
   def onStop(): Unit = {
-    supervisor ! PoisonPill
+    actorSupervisor ! PoisonPill
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 92b51ce..794dece 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -21,10 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SparkException, Logging, SparkConf}
 import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.util.RecurringTimer
-import org.apache.spark.util.SystemClock
+import org.apache.spark.util.{Clock, SystemClock}
 
 /** Listener object for BlockGenerator events */
 private[streaming] trait BlockGeneratorListener {
@@ -69,16 +69,35 @@ private[streaming] trait BlockGeneratorListener {
  * named blocks at regular intervals. This class starts two threads,
  * one to periodically start a new batch and prepare the previous batch of as a block,
  * the other to push the blocks into the block manager.
+ *
+ * Note: Do not create BlockGenerator instances directly inside receivers. Use
+ * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it.
  */
 private[streaming] class BlockGenerator(
     listener: BlockGeneratorListener,
     receiverId: Int,
-    conf: SparkConf
+    conf: SparkConf,
+    clock: Clock = new SystemClock()
   ) extends RateLimiter(conf) with Logging {
 
   private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any])
 
-  private val clock = new SystemClock()
+  /**
+   * The BlockGenerator can be in 5 possible states, in the order as follows.
+   * - Initialized: Nothing has been started
+   * - Active: start() has been called, and it is generating blocks on added data.
+   * - StoppedAddingData: stop() has been called, the adding of data has been stopped,
+   *                      but blocks are still being generated and pushed.
+   * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but
+   *                            they are still being pushed.
+   * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed.
+   */
+  private object GeneratorState extends Enumeration {
+    type GeneratorState = Value
+    val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value
+  }
+  import GeneratorState._
+
   private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")
   require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value")
 
@@ -89,59 +108,100 @@ private[streaming] class BlockGenerator(
   private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
 
   @volatile private var currentBuffer = new ArrayBuffer[Any]
-  @volatile private var stopped = false
+  @volatile private var state = Initialized
 
   /** Start block generating and pushing threads. */
-  def start() {
-    blockIntervalTimer.start()
-    blockPushingThread.start()
-    logInfo("Started BlockGenerator")
+  def start(): Unit = synchronized {
+    if (state == Initialized) {
+      state = Active
+      blockIntervalTimer.start()
+      blockPushingThread.start()
+      logInfo("Started BlockGenerator")
+    } else {
+      throw new SparkException(
+        s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]")
+    }
   }
 
-  /** Stop all threads. */
-  def stop() {
+  /**
+   * Stop everything in the right order such that all the data added is pushed out correctly.
+   * - First, stop adding data to the current buffer.
+   * - Second, stop generating blocks.
+   * - Finally, wait for queue of to-be-pushed blocks to be drained.
+   */
+  def stop(): Unit = {
+    // Set the state to stop adding data
+    synchronized {
+      if (state == Active) {
+        state = StoppedAddingData
+      } else {
+        logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]")
+        return
+      }
+    }
+
+    // Stop generating blocks and set the state for block pushing thread to start draining the queue
     logInfo("Stopping BlockGenerator")
     blockIntervalTimer.stop(interruptTimer = false)
-    stopped = true
-    logInfo("Waiting for block pushing thread")
+    synchronized { state = StoppedGeneratingBlocks }
+
+    // Wait for the queue to drain and mark generated as stopped
+    logInfo("Waiting for block pushing thread to terminate")
     blockPushingThread.join()
+    synchronized { state = StoppedAll }
     logInfo("Stopped BlockGenerator")
   }
 
   /**
-   * Push a single data item into the buffer. All received data items
-   * will be periodically pushed into BlockManager.
+   * Push a single data item into the buffer.
    */
-  def addData (data: Any): Unit = synchronized {
-    waitToPush()
-    currentBuffer += data
+  def addData(data: Any): Unit = synchronized {
+    if (state == Active) {
+      waitToPush()
+      currentBuffer += data
+    } else {
+      throw new SparkException(
+        "Cannot add data as BlockGenerator has not been started or has been stopped")
+    }
   }
 
   /**
    * Push a single data item into the buffer. After buffering the data, the
-   * `BlockGeneratorListener.onAddData` callback will be called. All received data items
-   * will be periodically pushed into BlockManager.
+   * `BlockGeneratorListener.onAddData` callback will be called.
    */
   def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized {
-    waitToPush()
-    currentBuffer += data
-    listener.onAddData(data, metadata)
+    if (state == Active) {
+      waitToPush()
+      currentBuffer += data
+      listener.onAddData(data, metadata)
+    } else {
+      throw new SparkException(
+        "Cannot add data as BlockGenerator has not been started or has been stopped")
+    }
   }
 
   /**
    * Push multiple data items into the buffer. After buffering the data, the
-   * `BlockGeneratorListener.onAddData` callback will be called. All received data items
-   * will be periodically pushed into BlockManager. Note that all the data items is guaranteed
-   * to be present in a single block.
+   * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items
+   * are atomically added to the buffer, and are hence guaranteed to be present in a single block.
    */
   def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized {
-    dataIterator.foreach { data =>
-      waitToPush()
-      currentBuffer += data
+    if (state == Active) {
+      dataIterator.foreach { data =>
+        waitToPush()
+        currentBuffer += data
+      }
+      listener.onAddData(dataIterator, metadata)
+    } else {
+      throw new SparkException(
+        "Cannot add data as BlockGenerator has not been started or has been stopped")
     }
-    listener.onAddData(dataIterator, metadata)
   }
 
+  def isActive(): Boolean = state == Active
+
+  def isStopped(): Boolean = state == StoppedAll
+
   /** Change the buffer to which single records are added to. */
   private def updateCurrentBuffer(time: Long): Unit = synchronized {
     try {
@@ -165,18 +225,21 @@ private[streaming] class BlockGenerator(
   /** Keep pushing blocks to the BlockManager. */
   private def keepPushingBlocks() {
     logInfo("Started block pushing thread")
+
+    def isGeneratingBlocks = synchronized { state == Active || state == StoppedAddingData }
     try {
-      while (!stopped) {
-        Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match {
+      while (isGeneratingBlocks) {
+        Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match {
           case Some(block) => pushBlock(block)
           case None =>
         }
       }
-      // Push out the blocks that are still left
+
+      // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks.
       logInfo("Pushing out the last " + blocksForPushing.size() + " blocks")
       while (!blocksForPushing.isEmpty) {
-        logDebug("Getting block ")
         val block = blocksForPushing.take()
+        logDebug(s"Pushing block $block")
         pushBlock(block)
         logInfo("Blocks left to push " + blocksForPushing.size())
       }
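The state machine and the three-step `stop()` above can be condensed into a small runnable sketch. It is a simplified, hypothetical stand-in for `BlockGenerator`: there is no `RecurringTimer` and no rate limiting, `cutBlock()` plays the role of the timer tick that swaps out `currentBuffer`, and `IllegalStateException` replaces `SparkException` to keep it self-contained.

```scala
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import scala.collection.mutable.ArrayBuffer

// Hypothetical simplified generator: five states plus the stop() ordering from the diff.
class SketchBlockGenerator(push: Seq[Any] => Unit) {
  object State extends Enumeration {
    val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value
  }
  import State._

  @volatile private var state = Initialized
  private var currentBuffer = new ArrayBuffer[Any]
  private val blocksForPushing = new LinkedBlockingQueue[Seq[Any]]()

  private val pushingThread = new Thread(new Runnable {
    def run(): Unit = {
      def isGeneratingBlocks = SketchBlockGenerator.this.synchronized {
        state == Active || state == StoppedAddingData
      }
      // Poll with a short timeout so the loop notices a state change promptly.
      while (isGeneratingBlocks) {
        Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)).foreach(push)
      }
      // Block generation has stopped: drain whatever is still queued.
      while (!blocksForPushing.isEmpty) push(blocksForPushing.take())
    }
  })

  def start(): Unit = synchronized {
    if (state != Initialized) throw new IllegalStateException(s"Cannot start in state $state")
    state = Active
    pushingThread.setDaemon(true)
    pushingThread.start()
  }

  def addData(data: Any): Unit = synchronized {
    if (state != Active) throw new IllegalStateException(s"Cannot add data in state $state")
    currentBuffer += data
  }

  // Stand-in for the timer tick: turn the current buffer into a queued block.
  def cutBlock(): Unit = {
    val block = synchronized {
      val b = currentBuffer.toList
      currentBuffer = new ArrayBuffer[Any]
      b
    }
    if (block.nonEmpty) blocksForPushing.put(block)
  }

  def stop(): Unit = {
    val wasActive = synchronized {
      if (state == Active) { state = StoppedAddingData; true } else false
    }
    if (wasActive) {
      cutBlock()                                       // last tick after adds have stopped
      synchronized { state = StoppedGeneratingBlocks } // no further blocks are generated
      pushingThread.join()                             // wait for the queue to drain
      synchronized { state = StoppedAll }
    }
  }

  def isStopped: Boolean = state == StoppedAll
}
```

In this sketch the block is enqueued outside the monitor, so even with a bounded queue a blocked `put` could not hold the lock that the pushing thread's state check needs.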

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
index f663def..bca1fbc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
@@ -45,8 +45,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
   /**
    * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}.
    */
-  def getCurrentLimit: Long =
-    rateLimiter.getRate.toLong
+  def getCurrentLimit: Long = rateLimiter.getRate.toLong
 
   /**
    * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by
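The doc comment above encodes "no limit" as `Long.MaxValue`; using that sentinel instead of `None` is what lets the internal rate getters return `Long` rather than `Option[Long]`. The sketch below models only those value semantics. It is hypothetical: it stores the rate in a plain field instead of wrapping a real token-bucket limiter, and it represents the truncated "maximum rate configured by" setting as a `maxRate` constructor argument, assuming a non-positive value means "not configured".

```scala
// Hypothetical sketch of the rate-limit value semantics (not Spark's RateLimiter).
class SketchRateLimit(maxRate: Long) {
  // Long.MaxValue is the "no limit" sentinel, so callers never need an Option.
  @volatile private var current: Long =
    if (maxRate > 0) maxRate else Long.MaxValue

  def getCurrentLimit: Long = current

  // Non-positive updates are ignored; positive ones are capped at maxRate if it is set.
  def updateRate(newRate: Long): Unit =
    if (newRate > 0) {
      current = if (maxRate > 0) math.min(newRate, maxRate) else newRate
    }
}
```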

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
index 7504fa4..554aae0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
@@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * being pushed into Spark's memory.
    */
   def store(dataItem: T) {
-    executor.pushSingle(dataItem)
+    supervisor.pushSingle(dataItem)
   }
 
   /** Store an ArrayBuffer of received data as a data block into Spark's memory. */
   def store(dataBuffer: ArrayBuffer[T]) {
-    executor.pushArrayBuffer(dataBuffer, None, None)
+    supervisor.pushArrayBuffer(dataBuffer, None, None)
   }
 
   /**
@@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * for being used in the corresponding InputDStream.
    */
   def store(dataBuffer: ArrayBuffer[T], metadata: Any) {
-    executor.pushArrayBuffer(dataBuffer, Some(metadata), None)
+    supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None)
   }
 
   /** Store an iterator of received data as a data block into Spark's memory. */
   def store(dataIterator: Iterator[T]) {
-    executor.pushIterator(dataIterator, None, None)
+    supervisor.pushIterator(dataIterator, None, None)
   }
 
   /**
@@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * for being used in the corresponding InputDStream.
    */
   def store(dataIterator: java.util.Iterator[T], metadata: Any) {
-    executor.pushIterator(dataIterator, Some(metadata), None)
+    supervisor.pushIterator(dataIterator, Some(metadata), None)
   }
 
   /** Store an iterator of received data as a data block into Spark's memory. */
   def store(dataIterator: java.util.Iterator[T]) {
-    executor.pushIterator(dataIterator, None, None)
+    supervisor.pushIterator(dataIterator, None, None)
   }
 
   /**
@@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * for being used in the corresponding InputDStream.
    */
   def store(dataIterator: Iterator[T], metadata: Any) {
-    executor.pushIterator(dataIterator, Some(metadata), None)
+    supervisor.pushIterator(dataIterator, Some(metadata), None)
   }
 
   /**
@@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * that Spark is configured to use.
    */
   def store(bytes: ByteBuffer) {
-    executor.pushBytes(bytes, None, None)
+    supervisor.pushBytes(bytes, None, None)
   }
 
   /**
@@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * for being used in the corresponding InputDStream.
    */
   def store(bytes: ByteBuffer, metadata: Any) {
-    executor.pushBytes(bytes, Some(metadata), None)
+    supervisor.pushBytes(bytes, Some(metadata), None)
   }
 
   /** Report exceptions in receiving data. */
   def reportError(message: String, throwable: Throwable) {
-    executor.reportError(message, throwable)
+    supervisor.reportError(message, throwable)
   }
 
   /**
@@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * The `message` will be reported to the driver.
    */
   def restart(message: String) {
-    executor.restartReceiver(message)
+    supervisor.restartReceiver(message)
   }
 
   /**
@@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * The `message` and `exception` will be reported to the driver.
    */
   def restart(message: String, error: Throwable) {
-    executor.restartReceiver(message, Some(error))
+    supervisor.restartReceiver(message, Some(error))
   }
 
   /**
@@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * in a background thread.
    */
   def restart(message: String, error: Throwable, millisecond: Int) {
-    executor.restartReceiver(message, Some(error), millisecond)
+    supervisor.restartReceiver(message, Some(error), millisecond)
   }
 
   /** Stop the receiver completely. */
   def stop(message: String) {
-    executor.stop(message, None)
+    supervisor.stop(message, None)
   }
 
   /** Stop the receiver completely due to an exception */
   def stop(message: String, error: Throwable) {
-    executor.stop(message, Some(error))
+    supervisor.stop(message, Some(error))
   }
 
   /** Check if the receiver has started or not. */
   def isStarted(): Boolean = {
-    executor.isReceiverStarted()
+    supervisor.isReceiverStarted()
   }
 
   /**
@@ -238,7 +238,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
    * the receiving of data should be stopped.
    */
   def isStopped(): Boolean = {
-    executor.isReceiverStopped()
+    supervisor.isReceiverStopped()
   }
 
   /**
@@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
   private var id: Int = -1
 
   /** Handler object that runs the receiver. This is instantiated lazily in the worker. */
-  private[streaming] var executor_ : ReceiverSupervisor = null
+  @transient private var _supervisor : ReceiverSupervisor = null
 
   /** Set the ID of the DStream that this receiver is associated with. */
   private[streaming] def setReceiverId(id_ : Int) {
@@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
   }
 
   /** Attach Network Receiver executor to this receiver. */
-  private[streaming] def attachExecutor(exec: ReceiverSupervisor) {
-    assert(executor_ == null)
-    executor_ = exec
+  private[streaming] def attachSupervisor(exec: ReceiverSupervisor) {
+    assert(_supervisor == null)
+    _supervisor = exec
   }
 
-  /** Get the attached executor. */
-  private def executor: ReceiverSupervisor = {
-    assert(executor_ != null, "Executor has not been attached to this receiver")
-    executor_
+  /** Get the attached supervisor. */
+  private[streaming] def supervisor: ReceiverSupervisor = {
+    assert(_supervisor != null,
+      "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " +
+        "some computation in the receiver before the Receiver.onStart() has been called.")
+    _supervisor
   }
 }
 

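The renamed accessor follows an attach-once pattern: the supervisor is attached exactly once on the executor, is kept out of the receiver's serialized form via `@transient`, and any use before attachment fails fast with a descriptive message. A minimal sketch with hypothetical `SketchReceiver` and `SketchReceiverSupervisor` types (not Spark's classes):

```scala
// Hypothetical stand-in for ReceiverSupervisor: only the method the sketch uses.
trait SketchReceiverSupervisor {
  def pushSingle(data: Any): Unit
}

// Hypothetical stand-in for Receiver, showing the attach-once accessor.
abstract class SketchReceiver[T] extends Serializable {
  // @transient keeps the supervisor out of the receiver's serialized form.
  @transient private var _supervisor: SketchReceiverSupervisor = null

  def attachSupervisor(s: SketchReceiverSupervisor): Unit = {
    assert(_supervisor == null, "A supervisor has already been attached")
    _supervisor = s
  }

  // Fails fast if data is stored before a supervisor has been attached.
  def supervisor: SketchReceiverSupervisor = {
    assert(_supervisor != null,
      "A supervisor has not been attached yet; is work starting before onStart()?")
    _supervisor
  }

  def store(item: T): Unit = supervisor.pushSingle(item)
}
```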
http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index e98017a..158d1ba 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -44,8 +44,8 @@ private[streaming] abstract class ReceiverSupervisor(
   }
   import ReceiverState._
 
-  // Attach the executor to the receiver
-  receiver.attachExecutor(this)
+  // Attach the supervisor to the receiver
+  receiver.attachSupervisor(this)
 
   private val futureExecutionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128))
@@ -60,7 +60,7 @@ private[streaming] abstract class ReceiverSupervisor(
   private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000)
 
   /** The current maximum rate limit for this receiver. */
-  private[streaming] def getCurrentRateLimit: Option[Long] = None
+  private[streaming] def getCurrentRateLimit: Long = Long.MaxValue
 
   /** Exception associated with the stopping of the receiver */
   @volatile protected var stoppingError: Throwable = null
@@ -92,13 +92,30 @@ private[streaming] abstract class ReceiverSupervisor(
       optionalBlockId: Option[StreamBlockId]
     )
 
+  /**
+   * Create a custom [[BlockGenerator]] that the receiver implementation can directly control
+   * using their provided [[BlockGeneratorListener]].
+   *
+   * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl`
+   * will take care of it.
+   */
+  def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator
+
   /** Report errors. */
   def reportError(message: String, throwable: Throwable)
 
-  /** Called when supervisor is started */
+  /**
+   * Called when supervisor is started.
+   * Note that this must be called before the receiver.onStart() is called to ensure
+   * things like [[BlockGenerator]]s are started before the receiver starts sending data.
+   */
   protected def onStart() { }
 
-  /** Called when supervisor is stopped */
+  /**
+   * Called when supervisor is stopped.
+   * Note that this must be called after the receiver.onStop() is called to ensure
+   * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data.
+   */
   protected def onStop(message: String, error: Option[Throwable]) { }
 
   /** Called when receiver is started. Return true if the driver accepts us */
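The ordering contract in the new `onStart`/`onStop` doc comments can be sketched as follows. `OrderedSupervisor` is hypothetical and only records the sequence of calls; it assumes the supervisor's start path runs its own `onStart()` before starting the receiver, and its stop path stops the receiver before running `onStop()`, as those comments require.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical sketch: block generators start before the receiver produces data
// and are stopped only after the receiver has stopped producing it.
class OrderedSupervisor(log: ArrayBuffer[String]) {
  private def onStart(): Unit = log += "start block generators"
  private def onStop(): Unit = log += "stop block generators"
  private def startReceiver(): Unit = log += "receiver.onStart"
  private def stopReceiver(): Unit = log += "receiver.onStop"

  def start(): Unit = { onStart(); startReceiver() }
  def stop(): Unit = { stopReceiver(); onStop() }
}
```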

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 0d802f8..59ef58d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.receiver
 import java.nio.ByteBuffer
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import com.google.common.base.Throwables
@@ -81,15 +82,20 @@ private[streaming] class ReceiverSupervisorImpl(
           cleanupOldBlocks(threshTime)
         case UpdateRateLimit(eps) =>
           logInfo(s"Received a new rate limit: $eps.")
-          blockGenerator.updateRate(eps)
+          registeredBlockGenerators.foreach { bg =>
+            bg.updateRate(eps)
+          }
       }
     })
 
   /** Unique block ids if one wants to add blocks directly */
   private val newBlockId = new AtomicLong(System.currentTimeMillis())
 
+  private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator]
+    with mutable.SynchronizedBuffer[BlockGenerator]
+
   /** Divides received data records into data blocks for pushing in BlockManager. */
-  private val blockGenerator = new BlockGenerator(new BlockGeneratorListener {
+  private val defaultBlockGeneratorListener = new BlockGeneratorListener {
     def onAddData(data: Any, metadata: Any): Unit = { }
 
     def onGenerateBlock(blockId: StreamBlockId): Unit = { }
@@ -101,14 +107,15 @@ private[streaming] class ReceiverSupervisorImpl(
     def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
       pushArrayBuffer(arrayBuffer, None, Some(blockId))
     }
-  }, streamId, env.conf)
+  }
+  private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener)
 
-  override private[streaming] def getCurrentRateLimit: Option[Long] =
-    Some(blockGenerator.getCurrentLimit)
+  /** Get the current rate limit of the default block generator */
+  override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit
 
   /** Push a single record of received data into block generator. */
   def pushSingle(data: Any) {
-    blockGenerator.addData(data)
+    defaultBlockGenerator.addData(data)
   }
 
   /** Store an ArrayBuffer of received data as a data block into Spark's memory. */
@@ -162,11 +169,11 @@ private[streaming] class ReceiverSupervisorImpl(
   }
 
   override protected def onStart() {
-    blockGenerator.start()
+    registeredBlockGenerators.foreach { _.start() }
   }
 
   override protected def onStop(message: String, error: Option[Throwable]) {
-    blockGenerator.stop()
+    registeredBlockGenerators.foreach { _.stop() }
     env.rpcEnv.stop(endpoint)
   }
 
@@ -183,6 +190,16 @@ private[streaming] class ReceiverSupervisorImpl(
     logInfo("Stopped receiver " + streamId)
   }
 
+  override def createBlockGenerator(
+      blockGeneratorListener: BlockGeneratorListener): BlockGenerator = {
+    // Cleanup BlockGenerators that have already been stopped
+    registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() }
+
+    val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf)
+    registeredBlockGenerators += newBlockGenerator
+    newBlockGenerator
+  }
+
   /** Generate new block ID */
   private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 67c2d90..1bba7a1 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming
 
 import java.io.File
 
-import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 import scala.reflect.ClassTag
 
 import com.google.common.base.Charsets
@@ -33,7 +33,7 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
-import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver}
+import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver}
 import org.apache.spark.util.{Clock, ManualClock, Utils}
 
 /**
@@ -397,26 +397,24 @@ class CheckpointSuite extends TestSuiteBase {
     ssc = new StreamingContext(conf, batchDuration)
     ssc.checkpoint(checkpointDir)
 
-    val dstream = new RateLimitInputDStream(ssc) {
+    val dstream = new RateTestInputDStream(ssc) {
       override val rateController =
-        Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+        Some(new ReceiverRateController(id, new ConstantEstimator(200)))
     }
-    SingletonTestRateReceiver.reset()
 
     val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2))
     output.register()
     runStreams(ssc, 5, 5)
 
-    SingletonTestRateReceiver.reset()
     ssc = new StreamingContext(checkpointDir)
     ssc.start()
     val outputNew = advanceTimeWithRealDelay(ssc, 2)
 
-    eventually(timeout(5.seconds)) {
-      assert(dstream.getCurrentRateLimit === Some(200))
+    eventually(timeout(10.seconds)) {
+      assert(RateTestReceiver.getActive().nonEmpty)
+      assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200)
     }
     ssc.stop()
-    ssc = null
   }
 
   // This tests whether file input stream remembers what files were seen before

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index 13b4d17..01279b3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -129,32 +129,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
     }
   }
 
-  test("block generator") {
-    val blockGeneratorListener = new FakeBlockGeneratorListener
-    val blockIntervalMs = 200
-    val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms")
-    val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
-    val expectedBlocks = 5
-    val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2)
-    val generatedData = new ArrayBuffer[Int]
-
-    // Generate blocks
-    val startTime = System.currentTimeMillis()
-    blockGenerator.start()
-    var count = 0
-    while(System.currentTimeMillis - startTime < waitTime) {
-      blockGenerator.addData(count)
-      generatedData += count
-      count += 1
-      Thread.sleep(10)
-    }
-    blockGenerator.stop()
-
-    val recordedData = blockGeneratorListener.arrayBuffers.flatten
-    assert(blockGeneratorListener.arrayBuffers.size > 0)
-    assert(recordedData.toSet === generatedData.toSet)
-  }
-
   ignore("block generator throttling") {
     val blockGeneratorListener = new FakeBlockGeneratorListener
     val blockIntervalMs = 100
@@ -348,6 +322,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
     }
 
     override protected def onReceiverStart(): Boolean = true
+
+    override def createBlockGenerator(
+        blockGeneratorListener: BlockGeneratorListener): BlockGenerator = {
+      null
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
new file mode 100644
index 0000000..a38cc60
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.util.ManualClock
+import org.apache.spark.{SparkException, SparkConf, SparkFunSuite}
+
+class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter {
+
+  private val blockIntervalMs = 10
+  private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms")
+  @volatile private var blockGenerator: BlockGenerator = null
+
+  after {
+    if (blockGenerator != null) {
+      blockGenerator.stop()
+    }
+  }
+
+  test("block generation and data callbacks") {
+    val listener = new TestBlockGeneratorListener
+    val clock = new ManualClock()
+
+    require(blockIntervalMs > 5)
+    require(listener.onAddDataCalled === false)
+    require(listener.onGenerateBlockCalled === false)
+    require(listener.onPushBlockCalled === false)
+
+    // Verify that creating the generator does not start it
+    blockGenerator = new BlockGenerator(listener, 0, conf, clock)
+    assert(blockGenerator.isActive() === false, "block generator active before start()")
+    assert(blockGenerator.isStopped() === false, "block generator stopped before start()")
+    assert(listener.onAddDataCalled === false)
+    assert(listener.onGenerateBlockCalled === false)
+    assert(listener.onPushBlockCalled === false)
+
+    // Verify start marks the generator active, but does not call the callbacks
+    blockGenerator.start()
+    assert(blockGenerator.isActive() === true, "block generator not active after start()")
+    assert(blockGenerator.isStopped() === false, "block generator stopped after start()")
+    withClue("callbacks called before adding data") {
+      assert(listener.onAddDataCalled === false)
+      assert(listener.onGenerateBlockCalled === false)
+      assert(listener.onPushBlockCalled === false)
+    }
+
+    // Verify whether addData() adds data that is present in generated blocks
+    val data1 = 1 to 10
+    data1.foreach { blockGenerator.addData _ }
+    withClue("callbacks called on adding data without metadata and without block generation") {
+      assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback()
+      assert(listener.onGenerateBlockCalled === false)
+      assert(listener.onPushBlockCalled === false)
+    }
+    clock.advance(blockIntervalMs)  // advance clock to generate blocks
+    withClue("blocks not generated or pushed") {
+      eventually(timeout(1 second)) {
+        assert(listener.onGenerateBlockCalled === true)
+        assert(listener.onPushBlockCalled === true)
+      }
+    }
+    listener.pushedData should contain theSameElementsInOrderAs (data1)
+    assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback()
+
+    // Verify addDataWithCallback() adds data+metadata and callbacks are called correctly
+    val data2 = 11 to 20
+    val metadata2 = data2.map { _.toString }
+    data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) }
+    assert(listener.onAddDataCalled === true)
+    listener.addedData should contain theSameElementsInOrderAs (data2)
+    listener.addedMetadata should contain theSameElementsInOrderAs (metadata2)
+    clock.advance(blockIntervalMs)  // advance clock to generate blocks
+    eventually(timeout(1 second)) {
+      listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2)
+    }
+
+    // Verify addMultipleDataWithCallback() adds data+metadata and callbacks are called correctly
+    val data3 = 21 to 30
+    val metadata3 = "metadata"
+    blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3)
+    listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3)
+    clock.advance(blockIntervalMs)  // advance clock to generate blocks
+    eventually(timeout(1 second)) {
+      listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3)
+    }
+
+    // Stop the block generator by starting the stop on a different thread and
+    // then advancing the manual clock for the stopping to proceed.
+    val thread = stopBlockGenerator(blockGenerator)
+    eventually(timeout(1 second), interval(10 milliseconds)) {
+      clock.advance(blockIntervalMs)
+      assert(blockGenerator.isStopped() === true)
+    }
+    thread.join()
+
+    // Verify that the generator cannot be used any more
+    intercept[SparkException] {
+      blockGenerator.addData(1)
+    }
+    intercept[SparkException] {
+      blockGenerator.addDataWithCallback(1, 1)
+    }
+    intercept[SparkException] {
+      blockGenerator.addMultipleDataWithCallback(Iterator(1), 1)
+    }
+    intercept[SparkException] {
+      blockGenerator.start()
+    }
+    blockGenerator.stop()   // Calling stop again should be fine
+  }
+
+  test("stop ensures correct shutdown") {
+    val listener = new TestBlockGeneratorListener
+    val clock = new ManualClock()
+    blockGenerator = new BlockGenerator(listener, 0, conf, clock)
+    require(listener.onGenerateBlockCalled === false)
+    blockGenerator.start()
+    assert(blockGenerator.isActive() === true, "block generator not active after start()")
+    assert(blockGenerator.isStopped() === false)
+
+    val data = 1 to 1000
+    data.foreach { blockGenerator.addData _ }
+
+    // Verify that stop() shuts down everything in the right order
+    // - First, stop receiving new data
+    // - Second, wait for final block with all buffered data to be generated
+    // - Finally, wait for all blocks to be pushed
+    clock.advance(1) // make sure the timer has to wait for another interval to complete
+    val thread = stopBlockGenerator(blockGenerator)
+    eventually(timeout(1 second), interval(10 milliseconds)) {
+      assert(blockGenerator.isActive() === false)
+    }
+    assert(blockGenerator.isStopped() === false)
+
+    // Verify that data cannot be added
+    intercept[SparkException] {
+      blockGenerator.addData(1)
+    }
+    intercept[SparkException] {
+      blockGenerator.addDataWithCallback(1, null)
+    }
+    intercept[SparkException] {
+      blockGenerator.addMultipleDataWithCallback(Iterator(1), null)
+    }
+
+    // Verify that stop() stays blocked until another block containing all the data is generated
+    // This intercept always succeeds, as the body will either throw a timeout exception
+    // (expected, as stop() should not have completed yet) or a SparkException (unexpected,
+    // as it means stop() completed and the thread terminated).
+    val exception = intercept[Exception] {
+      failAfter(200 milliseconds) {
+        thread.join()
+        throw new SparkException(
+          "BlockGenerator.stop() completed before generating timer was stopped")
+      }
+    }
+    exception should not be a [SparkException]
+
+
+    // Verify that the final data is present in the final generated block and
+    // pushed before complete stop
+    assert(blockGenerator.isStopped() === false) // generator has not stopped yet
+    clock.advance(blockIntervalMs)   // force block generation
+    failAfter(1 second) {
+      thread.join()
+    }
+    assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped
+    assert(listener.pushedData === data, "All data not pushed by stop()")
+  }
+
+  test("block push errors are reported") {
+    val listener = new TestBlockGeneratorListener {
+      @volatile var errorReported = false
+      override def onPushBlock(
+          blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+        throw new SparkException("test")
+      }
+      override def onError(message: String, throwable: Throwable): Unit = {
+        errorReported = true
+      }
+    }
+    blockGenerator = new BlockGenerator(listener, 0, conf)
+    blockGenerator.start()
+    assert(listener.errorReported === false)
+    blockGenerator.addData(1)
+    eventually(timeout(1 second), interval(10 milliseconds)) {
+      assert(listener.errorReported === true)
+    }
+    blockGenerator.stop()
+  }
+
+  /**
+   * Helper method to stop the block generator with manual clock in a different thread,
+   * so that the main thread can advance the clock that allows the stopping to proceed.
+   */
+  private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = {
+    val thread = new Thread() {
+      override def run(): Unit = {
+        blockGenerator.stop()
+      }
+    }
+    thread.start()
+    thread
+  }
+
+  /** A listener for BlockGenerator that records the data in the callbacks */
+  private class TestBlockGeneratorListener extends BlockGeneratorListener {
+    val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any]
+    val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any]
+    val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any]
+    @volatile var onGenerateBlockCalled = false
+    @volatile var onAddDataCalled = false
+    @volatile var onPushBlockCalled = false
+
+    override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+      pushedData ++= arrayBuffer
+      onPushBlockCalled = true
+    }
+    override def onError(message: String, throwable: Throwable): Unit = {}
+    override def onGenerateBlock(blockId: StreamBlockId): Unit = {
+      onGenerateBlockCalled = true
+    }
+    override def onAddData(data: Any, metadata: Any): Unit = {
+      addedData += data
+      addedMetadata += metadata
+      onAddDataCalled = true
+    }
+  }
+}
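Taken together, the assertions in `BlockGeneratorSuite` describe a lifecycle: a new generator is neither active nor stopped; `start()` makes it active; `stop()` first stops accepting data (no longer active, not yet stopped), then waits for the final block to be generated and pushed, and only then reports stopped; afterwards `start()` and the `addData*` methods fail and a repeated `stop()` is harmless. The sketch below models just that state machine in plain Scala. The state names, `GeneratorLifecycle` and the two stop hooks are hypothetical, and the sketch throws `IllegalStateException` where the suite expects `SparkException`.

```scala
// Hypothetical lifecycle states, in the order a generator moves through them.
object GenState extends Enumeration {
  val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value
}

// Hypothetical model of the lifecycle checked by BlockGeneratorSuite (no real buffering).
class GeneratorLifecycle {
  @volatile private var state: GenState.Value = GenState.Initialized

  def isActive(): Boolean = state == GenState.Active
  def isStopped(): Boolean = state == GenState.StoppedAll

  def start(): Unit = synchronized {
    if (state != GenState.Initialized) {
      throw new IllegalStateException(s"Cannot start generator in state $state")
    }
    state = GenState.Active
  }

  def addData(data: Any): Unit = synchronized {
    if (state != GenState.Active) {
      throw new IllegalStateException(s"Cannot add data in state $state")
    }
    // A real generator would append `data` to the current block buffer here.
  }

  // Stops in phases; the hooks stand in for "generate the last block" and
  // "push all queued blocks". Calling stop() when not active does nothing.
  def stop(
      generateFinalBlock: () => Unit = () => (),
      pushQueuedBlocks: () => Unit = () => ()): Unit = {
    val wasActive = synchronized {
      if (state == GenState.Active) { state = GenState.StoppedAddingData; true } else false
    }
    if (wasActive) {
      generateFinalBlock()                      // no new data can arrive at this point
      state = GenState.StoppedGeneratingBlocks
      pushQueuedBlocks()                        // drain the blocks generated so far
      state = GenState.StoppedAll
    }
  }
}
```

Separating "stopped accepting data" from "fully stopped" is what lets a test observe `isActive() == false` while `isStopped()` is still `false`, and lets `stop()` ensure that data added before it was called ends up in a pushed block.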

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
index 921da77..1eb52b7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -18,10 +18,7 @@
 package org.apache.spark.streaming.scheduler
 
 import scala.collection.mutable
-import scala.reflect.ClassTag
-import scala.util.control.NonFatal
 
-import org.scalatest.Matchers._
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
@@ -32,72 +29,63 @@ class RateControllerSuite extends TestSuiteBase {
 
   override def useManualClock: Boolean = false
 
-  test("rate controller publishes updates") {
+  override def batchDuration: Duration = Milliseconds(50)
+
+  test("RateController - rate controller publishes updates after batches complete") {
     val ssc = new StreamingContext(conf, batchDuration)
     withStreamingContext(ssc) { ssc =>
-      val dstream = new RateLimitInputDStream(ssc)
+      val dstream = new RateTestInputDStream(ssc)
       dstream.register()
       ssc.start()
 
       eventually(timeout(10.seconds)) {
-        assert(dstream.publishCalls > 0)
+        assert(dstream.publishedRates > 0)
       }
     }
   }
 
-  test("publish rates reach receivers") {
+  test("ReceiverRateController - published rates reach receivers") {
     val ssc = new StreamingContext(conf, batchDuration)
     withStreamingContext(ssc) { ssc =>
-      val dstream = new RateLimitInputDStream(ssc) {
+      val estimator = new ConstantEstimator(100)
+      val dstream = new RateTestInputDStream(ssc) {
         override val rateController =
-          Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+          Some(new ReceiverRateController(id, estimator))
       }
       dstream.register()
-      SingletonTestRateReceiver.reset()
       ssc.start()
 
-      eventually(timeout(10.seconds)) {
-        assert(dstream.getCurrentRateLimit === Some(200))
+      // Wait for receiver to start
+      eventually(timeout(5.seconds)) {
+        assert(RateTestReceiver.getActive().nonEmpty)
       }
-    }
-  }
 
-  test("multiple publish rates reach receivers") {
-    val ssc = new StreamingContext(conf, batchDuration)
-    withStreamingContext(ssc) { ssc =>
-      val rates = Seq(100L, 200L, 300L)
-
-      val dstream = new RateLimitInputDStream(ssc) {
-        override val rateController =
-          Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*)))
+      // Update rate in the estimator and verify whether the rate was published to the receiver
+      def updateRateAndVerify(rate: Long): Unit = {
+        estimator.updateRate(rate)
+        eventually(timeout(5.seconds)) {
+          assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate)
+        }
       }
-      SingletonTestRateReceiver.reset()
-      dstream.register()
-
-      val observedRates = mutable.HashSet.empty[Long]
-      ssc.start()
 
-      eventually(timeout(20.seconds)) {
-        dstream.getCurrentRateLimit.foreach(observedRates += _)
-        // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver
-        observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
+      // Verify multiple rate updates
+      Seq(100, 200, 300).foreach { rate =>
+        updateRateAndVerify(rate)
       }
     }
   }
 }
 
-private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator {
-  private var idx: Int = 0
+private[streaming] class ConstantEstimator(@volatile private var rate: Long)
+  extends RateEstimator {
 
-  private def nextRate(): Double = {
-    val rate = rates(idx)
-    idx = (idx + 1) % rates.size
-    rate
+  def updateRate(newRate: Long): Unit = {
+    rate = newRate
   }
 
   def compute(
       time: Long,
       elements: Long,
       processingDelay: Long,
-      schedulingDelay: Long): Option[Double] = Some(nextRate())
+      schedulingDelay: Long): Option[Double] = Some(rate)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a078303/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index afad5f1..dd292ba 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -17,48 +17,43 @@
 
 package org.apache.spark.streaming.scheduler
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.SparkConf
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming._
-import org.apache.spark.streaming.receiver._
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.receiver._
 
 /** Testsuite for receiver scheduling */
 class ReceiverTrackerSuite extends TestSuiteBase {
-  val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
-
-  test("Receiver tracker - propagates rate limit") {
-    withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc =>
-      object ReceiverStartedWaiter extends StreamingListener {
-        @volatile
-        var started = false
-
-        override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
-          started = true
-        }
-      }
 
-      ssc.addStreamingListener(ReceiverStartedWaiter)
+  test("send rate update to receivers") {
+    withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
       ssc.scheduler.listenerBus.start(ssc.sc)
-      SingletonTestRateReceiver.reset()
 
       val newRateLimit = 100L
-      val inputDStream = new RateLimitInputDStream(ssc)
+      val inputDStream = new RateTestInputDStream(ssc)
       val tracker = new ReceiverTracker(ssc)
       tracker.start()
       try {
         // we wait until the Receiver has registered with the tracker,
         // otherwise our rate update is lost
         eventually(timeout(5 seconds)) {
-          assert(ReceiverStartedWaiter.started)
+          assert(RateTestReceiver.getActive().nonEmpty)
         }
+
+
+        // Verify that the rates of the block generators in the receiver get updated
+        val activeReceiver = RateTestReceiver.getActive().get
         tracker.sendRateUpdate(inputDStream.id, newRateLimit)
-        // this is an async message, we need to wait a bit for it to be processed
-        eventually(timeout(3 seconds)) {
-          assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+        eventually(timeout(5 seconds)) {
+          assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit,
+            "default block generator did not receive rate update")
+          assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit,
+            "other block generator did not receive rate update")
         }
       } finally {
         tracker.stop(false)
@@ -67,69 +62,73 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   }
 }
 
-/**
- * An input DStream with a hard-coded receiver that gives access to internals for testing.
- *
- * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test,
- *       or otherwise you may get {{{NotSerializableException}}} when trying to serialize
- *       the receiver.
- * @see [[[SingletonDummyReceiver]]].
- */
-private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+/** An input DStream for testing rate control */
+private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext)
   extends ReceiverInputDStream[Int](ssc_) {
 
-  override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver
-
-  def getCurrentRateLimit: Option[Long] = {
-    invokeExecutorMethod.getCurrentRateLimit
-  }
+  override def getReceiver(): Receiver[Int] = new RateTestReceiver(id)
 
   @volatile
-  var publishCalls = 0
+  var publishedRates = 0
 
   override val rateController: Option[RateController] = {
-    Some(new RateController(id, new ConstantEstimator(100.0)) {
+    Some(new RateController(id, new ConstantEstimator(100)) {
       override def publish(rate: Long): Unit = {
-        publishCalls += 1
+        publishedRates += 1
       }
     })
   }
+}
 
-  private def invokeExecutorMethod: ReceiverSupervisor = {
-    val c = classOf[Receiver[_]]
-    val ex = c.getDeclaredMethod("executor")
-    ex.setAccessible(true)
-    ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor]
+/** A receiver implementation for testing rate control */
+private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
+  extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+
+  private lazy val customBlockGenerator = supervisor.createBlockGenerator(
+    new BlockGeneratorListener {
+      override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {}
+      override def onError(message: String, throwable: Throwable): Unit = {}
+      override def onGenerateBlock(blockId: StreamBlockId): Unit = {}
+      override def onAddData(data: Any, metadata: Any): Unit = {}
+    }
+  )
+
+  setReceiverId(receiverId)
+
+  override def onStart(): Unit = {
+    customBlockGenerator  // access the lazy val so the custom block generator gets created
+    RateTestReceiver.registerReceiver(this)
   }
-}
 
-/**
- * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
- * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
- * serialized when receivers are installed on executors.
- *
- * @note It's necessary to be a top-level object, or else serialization would create another
- *       one on the executor side and we won't be able to read its rate limit.
- */
-private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) {
+  override def onStop(): Unit = {
+    RateTestReceiver.deregisterReceiver()
+  }
+
+  override def preferredLocation: Option[String] = host
 
-  /** Reset the object to be usable in another test. */
-  def reset(): Unit = {
-    executor_ = null
+  def getDefaultBlockGeneratorRateLimit(): Long = {
+    supervisor.getCurrentRateLimit
+  }
+
+  def getCustomBlockGeneratorRateLimit(): Long = {
+    customBlockGenerator.getCurrentLimit
   }
 }
 
 /**
- * Dummy receiver implementation
+ * A helper object to RateTestReceiver that gives access to the currently active RateTestReceiver
+ * instance.
  */
-private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
-  extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+private[streaming] object RateTestReceiver {
+  @volatile private var activeReceiver: RateTestReceiver = null
 
-  setReceiverId(receiverId)
-
-  override def onStart(): Unit = {}
+  def registerReceiver(receiver: RateTestReceiver): Unit = {
+    activeReceiver = receiver
+  }
 
-  override def onStop(): Unit = {}
+  def deregisterReceiver(): Unit = {
+    activeReceiver = null
+  }
 
-  override def preferredLocation: Option[String] = host
+  def getActive(): Option[RateTestReceiver] = Option(activeReceiver)
 }
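The companion object replaces `SingletonTestRateReceiver`: rather than sharing one receiver object that had to be `reset()` before it could be serialized, each receiver instance registers itself in `onStart()` and clears the slot in `onStop()`, and tests look it up through `getActive()`. This relies on the tests running with a local master, where the receiver executes in the test's JVM and therefore sees the same companion object. A plain-Scala sketch of the pattern follows; `ActiveRegistry` and `FakeReceiver` are hypothetical names, not part of the commit.

```scala
// Hypothetical generic form of the RateTestReceiver companion object:
// one volatile slot holding the most recently started instance.
class ActiveRegistry[T] {
  @volatile private var active: Option[T] = None
  def register(instance: T): Unit = { active = Some(instance) }
  def deregister(): Unit = { active = None }
  def getActive(): Option[T] = active
}

// Hypothetical receiver that registers itself when started, as RateTestReceiver does.
class FakeReceiver(val id: Int) {
  def onStart(): Unit = FakeReceiver.registry.register(this)
  def onStop(): Unit = FakeReceiver.registry.deregister()
}

object FakeReceiver {
  // Shared by every FakeReceiver in this JVM, like the companion object in the diff.
  val registry = new ActiveRegistry[FakeReceiver]
}
```

As in the diff, the slot holds a single instance and `deregister()` clears it unconditionally, so the pattern assumes one receiver per test; a receiver that stops after a newer one has started would also clear the newer entry.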

