Repository: spark
Updated Branches:
  refs/heads/branch-1.4 2846a357f -> 59fc3f197


[SPARK-8200] [MLLIB] Check for empty RDDs in StreamingLinearAlgorithm

Adds test cases for both StreamingLinearRegression and StreamingLogisticRegression, along with the code fix.
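
At its core, the fix wraps the per-batch update in a guard so that micro-batches with no records are skipped before any data is touched. A minimal standalone sketch of that pattern (the object and method names below are illustrative, not taken from the patch):

  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.streaming.dstream.DStream

  // Hypothetical helper showing the guard: train only on non-empty micro-batches.
  object EmptyBatchGuard {
    def updateOn(data: DStream[LabeledPoint]): Unit = {
      data.foreachRDD { rdd =>
        if (!rdd.isEmpty) {
          // rdd.first() and the optimizer only ever see batches that contain data.
          val numFeatures = rdd.first().features.size
          println(s"updating a model over $numFeatures features")
          // ... run one training round with the current weights here ...
        }
      }
    }
  }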

Edit:
This contribution is my original work and I license the work to the project 
under the project's open source license.

Author: Paavo <ppark...@gmail.com>

Closes #6713 from pparkkin/streamingmodel-empty-rdd and squashes the following 
commits:

ff5cd78 [Paavo] Update strings to use interpolation.
db234cf [Paavo] Use !rdd.isEmpty.
54ad89e [Paavo] Test case for empty stream.
393e36f [Paavo] Ignore empty RDDs.
0bfc365 [Paavo] Test case for empty stream.

(cherry picked from commit b928f543845ddd39e914a0e8f0b0205fd86100c5)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/59fc3f19
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/59fc3f19
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/59fc3f19

Branch: refs/heads/branch-1.4
Commit: 59fc3f197247c6c8c40ea7479573af023c89d718
Parents: 2846a35
Author: Paavo <ppark...@gmail.com>
Authored: Wed Jun 10 23:17:42 2015 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Jun 10 23:26:54 2015 +0100

----------------------------------------------------------------------
 .../regression/StreamingLinearAlgorithm.scala   | 28 +++++++++++---------
 .../StreamingLogisticRegressionSuite.scala      | 17 ++++++++++++
 .../StreamingLinearRegressionSuite.scala        | 18 +++++++++++++
 3 files changed, 50 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/59fc3f19/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index cea8f3f..2dd8aca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -83,21 +83,23 @@ abstract class StreamingLinearAlgorithm[
       throw new IllegalArgumentException("Model must be initialized before starting training.")
     }
     data.foreachRDD { (rdd, time) =>
-      val initialWeights =
-        model match {
-          case Some(m) =>
-            m.weights
-          case None =>
-            val numFeatures = rdd.first().features.size
-            Vectors.dense(numFeatures)
+      if (!rdd.isEmpty) {
+        val initialWeights =
+          model match {
+            case Some(m) =>
+              m.weights
+            case None =>
+              val numFeatures = rdd.first().features.size
+              Vectors.dense(numFeatures)
+          }
+        model = Some(algorithm.run(rdd, initialWeights))
+        logInfo(s"Model updated at time ${time.toString}")
+        val display = model.get.weights.size match {
+          case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+          case _ => model.get.weights.toArray.mkString("[", ",", "]")
         }
-      model = Some(algorithm.run(rdd, initialWeights))
-      logInfo("Model updated at time %s".format(time.toString))
-      val display = model.get.weights.size match {
-        case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
-        case _ => model.get.weights.toArray.mkString("[", ",", "]")
+        logInfo(s"Current model: weights, ${display}")
       }
-      logInfo("Current model: weights, %s".format (display))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/59fc3f19/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index e98b61e..fd65329 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
     assert(error.head > 0.8 & error.last < 0.2)
   }
+
+  // Test empty RDDs in a stream
+  test("handling empty RDDs in a stream") {
+    val model = new StreamingLogisticRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(-0.1))
+      .setStepSize(0.01)
+      .setNumIterations(10)
+    val numBatches = 10
+    val emptyInput = Seq.empty[Seq[LabeledPoint]]
+    val ssc = setupStreams(emptyInput,
+      (inputDStream: DStream[LabeledPoint]) => {
+        model.trainOn(inputDStream)
+        model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+      }
+    )
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/59fc3f19/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 9a37940..f5e2d31 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
     assert((error.head - error.last) > 2)
   }
+
+  // Test empty RDDs in a stream
+  test("handling empty RDDs in a stream") {
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.2)
+      .setNumIterations(25)
+    val numBatches = 10
+    val nPoints = 100
+    val emptyInput = Seq.empty[Seq[LabeledPoint]]
+    val ssc = setupStreams(emptyInput,
+      (inputDStream: DStream[LabeledPoint]) => {
+        model.trainOn(inputDStream)
+        model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+      }
+    )
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+  }
 }
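
For completeness, one way to drive a stream of empty batches at the fixed classes outside of TestSuiteBase is a queue stream with an empty default RDD. Everything below is an illustrative sketch, not part of this commit; the app name, master, batch interval, and timeout are arbitrary choices:

  import scala.collection.mutable
  import org.apache.spark.SparkConf
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
  import org.apache.spark.rdd.RDD
  import org.apache.spark.streaming.{Seconds, StreamingContext}

  // Hypothetical driver: every batch interval yields an empty RDD of LabeledPoint.
  object EmptyStreamDemo {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf().setMaster("local[2]").setAppName("EmptyStreamDemo")
      val ssc = new StreamingContext(conf, Seconds(1))

      // An empty queue plus an empty default RDD produces a data-free batch every interval.
      val emptyBatches = ssc.queueStream(
        mutable.Queue.empty[RDD[LabeledPoint]],
        oneAtATime = true,
        defaultRDD = ssc.sparkContext.emptyRDD[LabeledPoint])

      val model = new StreamingLinearRegressionWithSGD()
        .setInitialWeights(Vectors.dense(0.0, 0.0))

      // With the guard in place, empty batches are skipped instead of being
      // handed to the optimizer.
      model.trainOn(emptyBatches)

      ssc.start()
      ssc.awaitTerminationOrTimeout(5000)
      ssc.stop()
    }
  }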

