Repository: spark
Updated Branches:
  refs/heads/branch-1.3 400580228 -> 11b28b9b4


[SPARK-5601][MLLIB] make streaming linear algorithms Java-friendly

Overload `trainOn`, `predictOn`, and `predictOnValues`.

CC freeman-lab

Author: Xiangrui Meng <m...@databricks.com>

Closes #4432 from mengxr/streaming-java and squashes the following commits:

6a79b85 [Xiangrui Meng] add java test for streaming logistic regression
2d7b357 [Xiangrui Meng] organize imports
1f662b3 [Xiangrui Meng] make streaming linear algorithms Java-friendly

(cherry picked from commit 0e23ca9f805b46d9b3472330676e5c8db926b8f5)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/11b28b9b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/11b28b9b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/11b28b9b

Branch: refs/heads/branch-1.3
Commit: 11b28b9b458a87e84edfc50caca01e8f9f9a2bdb
Parents: 4005802
Author: Xiangrui Meng <m...@databricks.com>
Authored: Fri Feb 6 15:42:59 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Feb 6 15:43:05 2015 -0800

----------------------------------------------------------------------
 .../regression/StreamingLinearAlgorithm.scala   | 20 ++++-
 .../JavaStreamingLogisticRegressionSuite.java   | 82 ++++++++++++++++++++
 .../JavaStreamingLinearRegressionSuite.java     | 80 +++++++++++++++++++
 3 files changed, 181 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/11b28b9b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index c854f12..ce95c06 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -21,7 +21,9 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream}
 import org.apache.spark.streaming.dstream.DStream
 
 /**
@@ -76,7 +78,7 @@ abstract class StreamingLinearAlgorithm[
    *
    * @param data DStream containing labeled data
    */
-  def trainOn(data: DStream[LabeledPoint]) {
+  def trainOn(data: DStream[LabeledPoint]): Unit = {
     if (model.isEmpty) {
       throw new IllegalArgumentException("Model must be initialized before starting training.")
     }
@@ -99,6 +101,9 @@ abstract class StreamingLinearAlgorithm[
     }
   }
 
+  /** Java-friendly version of `trainOn`. */
+  def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream)
+
   /**
    * Use the model to make predictions on batches of data from a DStream
    *
@@ -112,6 +117,11 @@ abstract class StreamingLinearAlgorithm[
     data.map(model.get.predict)
   }
 
+  /** Java-friendly version of `predictOn`. */
+  def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = {
+    JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]])
+  }
+
   /**
    * Use the model to make predictions on the values of a DStream and carry over its keys.
    * @param data DStream containing feature vectors
@@ -124,4 +134,12 @@ abstract class StreamingLinearAlgorithm[
     }
     data.mapValues(model.get.predict)
   }
+
+
+  /** Java-friendly version of `predictOnValues`. */
+  def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = {
+    implicit val tag = fakeClassTag[K]
+    JavaPairDStream.fromPairDStream(
+      predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Double)]])
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/11b28b9b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
new file mode 100644
index 0000000..ac945ba
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
+
+public class JavaStreamingLogisticRegressionSuite implements Serializable {
+
+  protected transient JavaStreamingContext ssc;
+
+  @Before
+  public void setUp() {
+    SparkConf conf = new SparkConf()
+      .setMaster("local[2]")
+      .setAppName("test")
+      .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+    ssc = new JavaStreamingContext(conf, new Duration(1000));
+    ssc.checkpoint("checkpoint");
+  }
+
+  @After
+  public void tearDown() {
+    ssc.stop();
+    ssc = null;
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void javaAPI() {
+    List<LabeledPoint> trainingBatch = Lists.newArrayList(
+      new LabeledPoint(1.0, Vectors.dense(1.0)),
+      new LabeledPoint(0.0, Vectors.dense(0.0)));
+    JavaDStream<LabeledPoint> training =
+      attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
+    List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
+      new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
+      new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+    JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
+      attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
+    StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD()
+      .setNumIterations(2)
+      .setInitialWeights(Vectors.dense(0.0));
+    slr.trainOn(training);
+    JavaPairDStream<Integer, Double> prediction = slr.predictOnValues(test);
+    attachTestOutputStream(prediction.count());
+    runStreams(ssc, 2, 2);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/11b28b9b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
new file mode 100644
index 0000000..a4dd1ac
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
+
+public class JavaStreamingLinearRegressionSuite implements Serializable {
+
+  protected transient JavaStreamingContext ssc;
+
+  @Before
+  public void setUp() {
+    SparkConf conf = new SparkConf()
+      .setMaster("local[2]")
+      .setAppName("test")
+      .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+    ssc = new JavaStreamingContext(conf, new Duration(1000));
+    ssc.checkpoint("checkpoint");
+  }
+
+  @After
+  public void tearDown() {
+    ssc.stop();
+    ssc = null;
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void javaAPI() {
+    List<LabeledPoint> trainingBatch = Lists.newArrayList(
+      new LabeledPoint(1.0, Vectors.dense(1.0)),
+      new LabeledPoint(0.0, Vectors.dense(0.0)));
+    JavaDStream<LabeledPoint> training =
+      attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
+    List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
+      new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
+      new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+    JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
+      attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
+    StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD()
+      .setNumIterations(2)
+      .setInitialWeights(Vectors.dense(0.0));
+    slr.trainOn(training);
+    JavaPairDStream<Integer, Double> prediction = slr.predictOnValues(test);
+    attachTestOutputStream(prediction.count());
+    runStreams(ssc, 2, 2);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to