Repository: spark
Updated Branches:
  refs/heads/master 069a4c414 -> 819be46e5


[SPARK-8977] [STREAMING] Defines the RateEstimator interface, and implements the 
RateController

Based on #7471.

- [x] add a test that exercises the publish path from driver to receiver
- [ ] remove Serializable from `RateController` and `RateEstimator`

Author: Iulian Dragos <jagua...@gmail.com>
Author: François Garillot <franc...@garillot.net>

Closes #7600 from dragos/topic/streaming-bp/rate-controller and squashes the 
following commits:

f168c94 [Iulian Dragos] Latest review round.
5125e60 [Iulian Dragos] Fix style.
a2eb3b9 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into 
topic/streaming-bp/rate-controller
475e346 [Iulian Dragos] Latest round of reviews.
e9fb45e [Iulian Dragos] - Add a test for checkpointing - fixed serialization 
for RateController.executionContext
715437a [Iulian Dragos] Review comments and added a `reset` call in 
ReceiverTrackerTest.
e57c66b [Iulian Dragos] Added a couple of tests for the full scenario from 
driver to receivers, with several rate updates.
b425d32 [Iulian Dragos] Removed DeveloperAPI, removed rateEstimator field, 
removed Noop rate estimator, changed logic for initialising rate estimator.
238cfc6 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into 
topic/streaming-bp/rate-controller
34a389d [Iulian Dragos] Various style changes and a first test for the rate 
controller.
d32ca36 [François Garillot] [SPARK-8977][Streaming] Defines the RateEstimator 
interface, and implements the ReceiverRateController
8941cf9 [Iulian Dragos] Renames and other nitpicks.
162d9e5 [Iulian Dragos] Use Reflection for accessing truly private `executor` 
method and use the listener bus to know when receivers have registered 
(`onStart` is called before receivers have registered, leading to flaky 
behavior).
210f495 [Iulian Dragos] Revert "Added a few tests that measure the receiver’s 
rate."
0c51959 [Iulian Dragos] Added a few tests that measure the receiver’s rate.
261a051 [Iulian Dragos] - removed field to hold the current rate limit in rate 
limiter - made rate limit a Long and default to Long.MaxValue (consequence of 
the above) - removed custom `waitUntil` and replaced it by `eventually`
cd1397d [Iulian Dragos] Add a test for the propagation of a new rate limit from 
driver to receivers.
6369b30 [Iulian Dragos] Merge pull request #15 from huitseeker/SPARK-8975
d15de42 [François Garillot] [SPARK-8975][Streaming] Adds Ratelimiter unit 
tests w.r.t. spark.streaming.receiver.maxRate
4721c7d [François Garillot] [SPARK-8975][Streaming] Add a mechanism to send a 
new rate from the driver to the block generator


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/819be46e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/819be46e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/819be46e

Branch: refs/heads/master
Commit: 819be46e5a73f2d19230354ebba30c58538590f5
Parents: 069a4c4
Author: Iulian Dragos <jagua...@gmail.com>
Authored: Wed Jul 29 13:47:37 2015 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Wed Jul 29 13:47:37 2015 -0700

----------------------------------------------------------------------
 .../spark/streaming/dstream/InputDStream.scala  |   7 +-
 .../dstream/ReceiverInputDStream.scala          |  26 ++++-
 .../streaming/scheduler/JobScheduler.scala      |   6 ++
 .../streaming/scheduler/RateController.scala    |  90 ++++++++++++++++
 .../scheduler/rate/RateEstimator.scala          |  59 +++++++++++
 .../spark/streaming/CheckpointSuite.scala       |  28 +++++
 .../scheduler/RateControllerSuite.scala         | 103 +++++++++++++++++++
 .../ReceiverSchedulingPolicySuite.scala         |  10 +-
 .../scheduler/ReceiverTrackerSuite.scala        |  41 ++++++--
 9 files changed, 355 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index d58c99a..a6c4cd2 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -21,7 +21,9 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDDOperationScope
-import org.apache.spark.streaming.{Time, Duration, StreamingContext}
+import org.apache.spark.streaming.{Duration, StreamingContext, Time}
+import org.apache.spark.streaming.scheduler.RateController
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
 import org.apache.spark.util.Utils
 
 /**
@@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : 
StreamingContext)
   /** This is an unique identifier for the input stream. */
   val id = ssc.getNewInputStreamId()
 
+  // Keep track of the freshest rate for this stream using the rateEstimator
+  protected[streaming] val rateController: Option[RateController] = None
+
   /** A human-readable name of this InputDStream */
   private[streaming] def name: String = {
     // e.g. FlumePollingDStream -> "Flume polling stream"

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index a50f0ef..646a8c3 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -21,10 +21,11 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.BlockId
-import org.apache.spark.streaming._
+import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
 import org.apache.spark.streaming.receiver.Receiver
-import org.apache.spark.streaming.scheduler.StreamInputInfo
+import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
 import org.apache.spark.streaming.util.WriteAheadLogUtils
 
 /**
@@ -41,6 +42,17 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient 
ssc_ : StreamingCont
   extends InputDStream[T](ssc_) {
 
   /**
+   * Asynchronously maintains & sends new rate limits to the receiver through 
the receiver tracker.
+   */
+  override protected[streaming] val rateController: Option[RateController] = {
+    if (RateController.isBackPressureEnabled(ssc.conf)) {
+      RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
+    } else {
+      None
+    }
+  }
+
+  /**
    * Gets the receiver object that will be sent to the worker nodes
    * to receive data. This method needs to defined by any specific 
implementation
    * of a ReceiverInputDStream.
@@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: 
ClassTag](@transient ssc_ : StreamingCont
     }
     Some(blockRDD)
   }
+
+  /**
+   * A RateController that sends the new rate to receivers, via the receiver 
tracker.
+   */
+  private[streaming] class ReceiverRateController(id: Int, estimator: 
RateEstimator)
+      extends RateController(id, estimator) {
+    override def publish(rate: Long): Unit =
+      ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
+  }
 }
+

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 4af9b6d..58bdda7 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends 
Logging {
     }
     eventLoop.start()
 
+    // attach rate controllers of input streams to receive batch completion 
updates
+    for {
+      inputDStream <- ssc.graph.getInputStreams
+      rateController <- inputDStream.rateController
+    } ssc.addStreamingListener(rateController)
+
     listenerBus.start(ssc.sparkContext)
     receiverTracker = new ReceiverTracker(ssc)
     inputInfoTracker = new InputInfoTracker(ssc)

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
new file mode 100644
index 0000000..882ca06
--- /dev/null
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * A StreamingListener that receives batch completion updates, and maintains
+ * an estimate of the speed at which this stream should ingest messages,
+ * given an estimate computation from a `RateEstimator`
+ */
+private[streaming] abstract class RateController(val streamUID: Int, 
rateEstimator: RateEstimator)
+    extends StreamingListener with Serializable {
+
+  init()
+
+  protected def publish(rate: Long): Unit
+
+  @transient
+  implicit private var executionContext: ExecutionContext = _
+
+  @transient
+  private var rateLimit: AtomicLong = _
+
+  /**
+   * An initialization method called both from the constructor and from 
deserialization (`readObject`).
+   */
+  private def init() {
+    executionContext = ExecutionContext.fromExecutorService(
+      ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update"))
+    rateLimit = new AtomicLong(-1L)
+  }
+
+  private def readObject(ois: ObjectInputStream): Unit = 
Utils.tryOrIOException {
+    ois.defaultReadObject()
+    init()
+  }
+
+  /**
+   * Compute the new rate limit and publish it asynchronously.
+   */
+  private def computeAndPublish(time: Long, elems: Long, workDelay: Long, 
waitDelay: Long): Unit =
+    Future[Unit] {
+      val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
+      newRate.foreach { s =>
+        rateLimit.set(s.toLong)
+        publish(getLatestRate())
+      }
+    }
+
+  def getLatestRate(): Long = rateLimit.get()
+
+  override def onBatchCompleted(batchCompleted: 
StreamingListenerBatchCompleted) {
+    val elements = batchCompleted.batchInfo.streamIdToInputInfo
+
+    for {
+      processingEnd <- batchCompleted.batchInfo.processingEndTime;
+      workDelay <- batchCompleted.batchInfo.processingDelay;
+      waitDelay <- batchCompleted.batchInfo.schedulingDelay;
+      elems <- elements.get(streamUID).map(_.numRecords)
+    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
+  }
+}
+
+object RateController {
+  def isBackPressureEnabled(conf: SparkConf): Boolean =
+    conf.getBoolean("spark.streaming.backpressure.enable", false)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
new file mode 100644
index 0000000..a086851
--- /dev/null
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkException
+
+/**
+ * A component that estimates the rate at which an InputDStream should ingest
+ * elements, based on updates at every batch completion.
+ */
+private[streaming] trait RateEstimator extends Serializable {
+
+  /**
+   * Computes the number of elements the stream attached to this 
`RateEstimator`
+   * should ingest per second, given an update on the size and completion
+   * times of the latest batch.
+   *
+   * @param time The timestamp of the current batch interval that just finished
+   * @param elements The number of elements that were processed in this batch
+   * @param processingDelay The time in ms that the job took to complete
+   * @param schedulingDelay The time in ms that the job spent in the 
scheduling queue
+   */
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double]
+}
+
+object RateEstimator {
+
+  /**
+   * Return a new RateEstimator based on the value of 
`spark.streaming.backpressure.rateEstimator`.
+   *
+   * @return None if there is no configured estimator, otherwise an instance 
of RateEstimator
+   * @throws IllegalArgumentException if there is a configured RateEstimator 
that doesn't match any
+   *         known estimators.
+   */
+  def create(conf: SparkConf): Option[RateEstimator] =
+    conf.getOption("spark.streaming.backpressure.rateEstimator").map { 
estimator =>
+      throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
+    }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index d308ac0..67c2d90 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text}
 import org.apache.hadoop.mapred.TextOutputFormat
 import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => 
NewTextOutputFormat}
 import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, 
ConstantEstimator, SingletonTestRateReceiver}
 import org.apache.spark.util.{Clock, ManualClock, Utils}
 
 /**
@@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase {
     testCheckpointedOperation(input, operation, output, 7)
   }
 
+  test("recovery maintains rate controller") {
+    ssc = new StreamingContext(conf, batchDuration)
+    ssc.checkpoint(checkpointDir)
+
+    val dstream = new RateLimitInputDStream(ssc) {
+      override val rateController =
+        Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+    }
+    SingletonTestRateReceiver.reset()
+
+    val output = new 
TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2))
+    output.register()
+    runStreams(ssc, 5, 5)
+
+    SingletonTestRateReceiver.reset()
+    ssc = new StreamingContext(checkpointDir)
+    ssc.start()
+    val outputNew = advanceTimeWithRealDelay(ssc, 2)
+
+    eventually(timeout(5.seconds)) {
+      assert(dstream.getCurrentRateLimit === Some(200))
+    }
+    ssc.stop()
+    ssc = null
+  }
+
   // This tests whether file input stream remembers what files were seen before
   // the master failure and uses them again to process a large window 
operation.
   // It also tests whether batches, whose processing was incomplete due to the

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
new file mode 100644
index 0000000..921da77
--- /dev/null
+++ 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.control.NonFatal
+
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
+class RateControllerSuite extends TestSuiteBase {
+
+  override def useManualClock: Boolean = false
+
+  test("rate controller publishes updates") {
+    val ssc = new StreamingContext(conf, batchDuration)
+    withStreamingContext(ssc) { ssc =>
+      val dstream = new RateLimitInputDStream(ssc)
+      dstream.register()
+      ssc.start()
+
+      eventually(timeout(10.seconds)) {
+        assert(dstream.publishCalls > 0)
+      }
+    }
+  }
+
+  test("publish rates reach receivers") {
+    val ssc = new StreamingContext(conf, batchDuration)
+    withStreamingContext(ssc) { ssc =>
+      val dstream = new RateLimitInputDStream(ssc) {
+        override val rateController =
+          Some(new ReceiverRateController(id, new ConstantEstimator(200.0)))
+      }
+      dstream.register()
+      SingletonTestRateReceiver.reset()
+      ssc.start()
+
+      eventually(timeout(10.seconds)) {
+        assert(dstream.getCurrentRateLimit === Some(200))
+      }
+    }
+  }
+
+  test("multiple publish rates reach receivers") {
+    val ssc = new StreamingContext(conf, batchDuration)
+    withStreamingContext(ssc) { ssc =>
+      val rates = Seq(100L, 200L, 300L)
+
+      val dstream = new RateLimitInputDStream(ssc) {
+        override val rateController =
+          Some(new ReceiverRateController(id, new 
ConstantEstimator(rates.map(_.toDouble): _*)))
+      }
+      SingletonTestRateReceiver.reset()
+      dstream.register()
+
+      val observedRates = mutable.HashSet.empty[Long]
+      ssc.start()
+
+      eventually(timeout(20.seconds)) {
+        dstream.getCurrentRateLimit.foreach(observedRates += _)
+        // Long.MaxValue (essentially, no rate limit) is the initial rate 
limit for any Receiver
+        observedRates should contain theSameElementsAs (rates :+ Long.MaxValue)
+      }
+    }
+  }
+}
+
+private[streaming] class ConstantEstimator(rates: Double*) extends 
RateEstimator {
+  private var idx: Int = 0
+
+  private def nextRate(): Double = {
+    val rate = rates(idx)
+    idx = (idx + 1) % rates.size
+    rate
+  }
+
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double] = Some(nextRate())
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
index 93f920f..0418d77 100644
--- 
a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
+++ 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
@@ -64,7 +64,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
 
   test("scheduleReceivers: " +
     "schedule receivers evenly when there are more receivers than executors") {
-    val receivers = (0 until 6).map(new DummyReceiver(_))
+    val receivers = (0 until 6).map(new RateTestReceiver(_))
     val executors = (10000 until 10003).map(port => s"localhost:${port}")
     val scheduledExecutors = 
receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
     val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -79,7 +79,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
 
   test("scheduleReceivers: " +
     "schedule receivers evenly when there are more executors than receivers") {
-    val receivers = (0 until 3).map(new DummyReceiver(_))
+    val receivers = (0 until 3).map(new RateTestReceiver(_))
     val executors = (10000 until 10006).map(port => s"localhost:${port}")
     val scheduledExecutors = 
receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
     val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -94,8 +94,8 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
   }
 
   test("scheduleReceivers: schedule receivers evenly when the 
preferredLocations are even") {
-    val receivers = (0 until 3).map(new DummyReceiver(_)) ++
-      (3 until 6).map(new DummyReceiver(_, Some("localhost")))
+    val receivers = (0 until 3).map(new RateTestReceiver(_)) ++
+      (3 until 6).map(new RateTestReceiver(_, Some("localhost")))
     val executors = (10000 until 10003).map(port => s"localhost:${port}") ++
       (10003 until 10006).map(port => s"localhost2:${port}")
     val scheduledExecutors = 
receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
@@ -121,7 +121,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
   }
 
   test("scheduleReceivers: return empty scheduled executors if no executors") {
-    val receivers = (0 until 3).map(new DummyReceiver(_))
+    val receivers = (0 until 3).map(new RateTestReceiver(_))
     val scheduledExecutors = 
receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty)
     scheduledExecutors.foreach { case (receiverId, executors) =>
       assert(executors.isEmpty)

http://git-wip-us.apache.org/repos/asf/spark/blob/819be46e/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index b039233..aff8b53 100644
--- 
a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -43,6 +43,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
 
     ssc.addStreamingListener(ReceiverStartedWaiter)
     ssc.scheduler.listenerBus.start(ssc.sc)
+    SingletonTestRateReceiver.reset()
 
     val newRateLimit = 100L
     val inputDStream = new RateLimitInputDStream(ssc)
@@ -62,36 +63,62 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   }
 }
 
-/** An input DStream with a hard-coded receiver that gives access to internals 
for testing. */
-private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+/**
+ * An input DStream with a hard-coded receiver that gives access to internals 
for testing.
+ *
+ * @note Make sure to call `SingletonTestRateReceiver.reset()` before using 
this in a test,
+ *       or otherwise you may get a `NotSerializableException` when trying 
to serialize
+ *       the receiver.
+ * @see [[SingletonTestRateReceiver]].
+ */
+private[streaming] class RateLimitInputDStream(@transient ssc_ : 
StreamingContext)
   extends ReceiverInputDStream[Int](ssc_) {
 
-  override def getReceiver(): DummyReceiver = SingletonDummyReceiver
+  override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver
 
   def getCurrentRateLimit: Option[Long] = {
     invokeExecutorMethod.getCurrentRateLimit
   }
 
+  @volatile
+  var publishCalls = 0
+
+  override val rateController: Option[RateController] = {
+    Some(new RateController(id, new ConstantEstimator(100.0)) {
+      override def publish(rate: Long): Unit = {
+        publishCalls += 1
+      }
+    })
+  }
+
   private def invokeExecutorMethod: ReceiverSupervisor = {
     val c = classOf[Receiver[_]]
     val ex = c.getDeclaredMethod("executor")
     ex.setAccessible(true)
-    ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor]
+    ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor]
   }
 }
 
 /**
- * A Receiver as an object so we can read its rate limit.
+ * A Receiver as an object so we can read its rate limit. Make sure to call 
`reset()` when
+ * reusing this receiver, otherwise a non-null `executor_` field will prevent 
it from being
+ * serialized when receivers are installed on executors.
  *
  * @note It's necessary to be a top-level object, or else serialization would 
create another
  *       one on the executor side and we won't be able to read its rate limit.
  */
-private object SingletonDummyReceiver extends DummyReceiver(0)
+private[streaming] object SingletonTestRateReceiver extends 
RateTestReceiver(0) {
+
+  /** Reset the object to be usable in another test. */
+  def reset(): Unit = {
+    executor_ = null
+  }
+}
 
 /**
  * Dummy receiver implementation
  */
-private class DummyReceiver(receiverId: Int, host: Option[String] = None)
+private[streaming] class RateTestReceiver(receiverId: Int, host: 
Option[String] = None)
   extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
 
   setReceiverId(receiverId)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to