Repository: spark
Updated Branches:
  refs/heads/master cbfc26ba4 -> 31f0b071e


[SPARK-3128][MLLIB] Use streaming test suite for StreamingLR

Refactored tests for streaming linear regression to use existing streaming
test utilities. Summary of changes:
- Made ``mllib`` depend on tests from ``streaming``
- Rewrote accuracy and convergence tests to use ``setupStreams`` and 
``runStreams``
- Added a new test for the accuracy of predictions generated by ``predictOnValues``

These tests should run faster, be easier to extend/maintain, and provide a
reference for new tests (a sketch of the pattern follows).
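
As a reference, here is a minimal sketch of the pattern the refactored suite
now follows. The suite name, batch count, and tolerance are illustrative
only, not part of this commit:

    import org.scalatest.FunSuite

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
    import org.apache.spark.mllib.util.LinearDataGenerator
    import org.apache.spark.streaming.TestSuiteBase
    import org.apache.spark.streaming.dstream.DStream

    class ExampleStreamingRegressionSuite extends FunSuite with TestSuiteBase {

      test("example: train on a simulated stream") {
        val model = new StreamingLinearRegressionWithSGD()
          .setInitialWeights(Vectors.dense(0.0))
          .setStepSize(0.1)
          .setNumIterations(25)

        // one inner Seq per batch; each batch becomes one RDD in the test stream
        val numBatches = 5
        val input = (0 until numBatches).map { i =>
          LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
        }

        // train as a side effect; count() returns a DStream for runStreams to collect
        val ssc = setupStreams(input, (stream: DStream[LabeledPoint]) => {
          model.trainOn(stream)
          stream.count()
        })
        runStreams(ssc, numBatches, numBatches)

        // the learned weight should approach the true value of 10.0
        assert(math.abs(model.latestModel().weights(0) - 10.0) < 1.0)
      }
    }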

mengxr tdas

Author: freeman <the.freeman....@gmail.com>

Closes #2037 from freeman-lab/streamingLR-predict-tests and squashes the 
following commits:

e851ca7 [freeman] Fixed long lines
50eb0bf [freeman] Refactored tests to use streaming test tools
32c43c2 [freeman] Added test for prediction


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31f0b071
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31f0b071
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31f0b071

Branch: refs/heads/master
Commit: 31f0b071efd0b63eb9d6a6a131e5c4fa28237583
Parents: cbfc26b
Author: freeman <the.freeman....@gmail.com>
Authored: Tue Aug 19 13:28:57 2014 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Aug 19 13:28:57 2014 -0700

----------------------------------------------------------------------
 mllib/pom.xml                                   |   7 ++
 .../StreamingLinearRegressionSuite.scala        | 121 ++++++++++---------
 .../apache/spark/streaming/TestSuiteBase.scala  |   4 +-
 3 files changed, 77 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/31f0b071/mllib/pom.xml
----------------------------------------------------------------------
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fc1ecfb..c7a1e2a 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -91,6 +91,13 @@
       <artifactId>junit-interface</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <profiles>
     <profile>
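
Note: the test-jar dependency above is what makes ``TestSuiteBase`` and the
other ``streaming`` test utilities visible to the ``mllib`` tests. For an sbt
build, the equivalent would use the ``tests`` classifier; a sketch, with
``sparkVersion`` standing in for the project's version:

    // sbt sketch: depend on spark-streaming's published test artifacts
    libraryDependencies +=
      "org.apache.spark" %% "spark-streaming" % sparkVersion % "test" classifier "tests"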

http://git-wip-us.apache.org/repos/asf/spark/blob/31f0b071/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 45e25ee..2848941 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -17,20 +17,19 @@
 
 package org.apache.spark.mllib.regression
 
-import java.io.File
-import java.nio.charset.Charset
-
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.Files
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
 
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
-  test("streaming linear regression parameter accuracy") {
+  test("parameter accuracy") {
 
-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)
 
-    model.trainOn(data)
-
-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
-
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
     // check accuracy of final parameter estimates
     assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
-  test("streaming linear regression parameter convergence") {
+  test("parameter convergence") {
 
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
-
-    ssc.start()
+      .setNumIterations(25)
 
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)
 
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
+    // compute change in error
     val deltas = history.drop(1).zip(history.dropRight(1))
     // check error stability (it always either shrinks, or increases with small tol)
     assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  // Test predictions on a stream
+  test("predictions") {
+
+    // create model initialized with true weights
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
+      .setStepSize(0.1)
+      .setNumIterations(25)
+
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+    }
+
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
+    assert(errors.forall(x => x <= 0.1))
+
+  }
+
 }
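
For context on the new ``predictions`` test: ``predictOnValues`` maps a stream
of (key, features) pairs to (key, prediction) pairs, so keying each point by
its true label is what produces the (true, estimated) tuples collected above.
A sketch of that transformation in isolation (the helper name is hypothetical):

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
    import org.apache.spark.streaming.dstream.DStream

    // key each point by its true label, then predict on the feature vector
    def labelAndPrediction(
        model: StreamingLinearRegressionWithSGD,
        stream: DStream[LabeledPoint]): DStream[(Double, Double)] = {
      val keyed: DStream[(Double, Vector)] = stream.map(lp => (lp.label, lp.features))
      model.predictOnValues(keyed)
    }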

http://git-wip-us.apache.org/repos/asf/spark/blob/31f0b071/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index cc178fb..f095da9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -242,7 +242,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
 
     // Get the output buffer
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+    val outputStream = ssc.graph.getOutputStreams.
+      filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]).
+      head.asInstanceOf[TestOutputStreamWithPartitions[V]]
     val output = outputStream.output
 
     try {
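
For context on this change: applying ``trainOn`` inside ``setupStreams``
registers its own ForEachDStream as an output stream, so the first entry in
``ssc.graph.getOutputStreams`` is no longer guaranteed to be the
``TestOutputStreamWithPartitions`` that ``runStreams`` reads from; filtering
by type makes the lookup robust to extra output streams. An equivalent
formulation (a sketch, not what the commit uses):

    // sketch: locate the test output stream by type
    val testOutput = ssc.graph.getOutputStreams
      .collectFirst { case t: TestOutputStreamWithPartitions[_] => t }
      .getOrElse(sys.error("no TestOutputStreamWithPartitions registered"))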

