Repository: spark
Updated Branches:
  refs/heads/branch-1.6 6c1bf19e8 -> 05666e09b


[SPARK-11663][STREAMING] Add Java API for trackStateByKey

TODO
- [x] Add Java API
- [x] Add API tests
- [x] Add a function test

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #9636 from zsxwing/java-track.

(cherry picked from commit 0f1d00a905614bb5eebf260566dbcb831158d445)
Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/05666e09
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/05666e09
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/05666e09

Branch: refs/heads/branch-1.6
Commit: 05666e09bdceafe25540e674efdd6eb70fe37fa0
Parents: 6c1bf19
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Thu Nov 12 17:48:43 2015 -0800
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Thu Nov 12 17:48:51 2015 -0800

----------------------------------------------------------------------
 .../spark/api/java/function/Function4.java      |  27 +++
 .../streaming/JavaStatefulNetworkWordCount.java |  45 ++--
 .../streaming/StatefulNetworkWordCount.scala    |   2 +-
 .../apache/spark/streaming/Java8APISuite.java   |  43 ++++
 .../org/apache/spark/streaming/State.scala      |  25 ++-
 .../org/apache/spark/streaming/StateSpec.scala  |  84 ++++++--
 .../streaming/api/java/JavaPairDStream.scala    |  46 +++-
 .../api/java/JavaTrackStateDStream.scala        |  44 ++++
 .../streaming/dstream/TrackStateDStream.scala   |   1 +
 .../spark/streaming/rdd/TrackStateRDD.scala     |   4 +-
 .../apache/spark/streaming/util/StateMap.scala  |   6 +-
 .../streaming/JavaTrackStateByKeySuite.java     | 210 +++++++++++++++++++
 12 files changed, 485 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/core/src/main/java/org/apache/spark/api/java/function/Function4.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/api/java/function/Function4.java 
b/core/src/main/java/org/apache/spark/api/java/function/Function4.java
new file mode 100644
index 0000000..fd727d6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A serializable function of four arguments, of types T1, T2, T3 and T4, that produces a
+ * result of type R. The call may throw any checked exception.
+ */
+public interface Function4<T1, T2, T3, T4, R> extends Serializable {
+  R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
 
b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index 99b63a2..c400e42 100644
--- 
a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ 
b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -26,18 +26,15 @@ import scala.Tuple2;
 import com.google.common.base.Optional;
 import com.google.common.collect.Lists;
 
-import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.StorageLevels;
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.Durations;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
-import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.State;
+import org.apache.spark.streaming.StateSpec;
+import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.*;
 
 /**
  * Counts words cumulatively in UTF8 encoded, '\n' delimited text received 
from the network every
@@ -63,25 +60,12 @@ public class JavaStatefulNetworkWordCount {
 
     StreamingExamples.setStreamingLogLevels();
 
-    // Update the cumulative count function
-    final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> 
updateFunction =
-        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
-          @Override
-          public Optional<Integer> call(List<Integer> values, 
Optional<Integer> state) {
-            Integer newSum = state.or(0);
-            for (Integer value : values) {
-              newSum += value;
-            }
-            return Optional.of(newSum);
-          }
-        };
-
     // Create the context with a 1 second batch size
     SparkConf sparkConf = new 
SparkConf().setAppName("JavaStatefulNetworkWordCount");
     JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, 
Durations.seconds(1));
     ssc.checkpoint(".");
 
-    // Initial RDD input to updateStateByKey
+    // Initial RDD input to trackStateByKey
     @SuppressWarnings("unchecked")
     List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, 
Integer>("hello", 1),
             new Tuple2<String, Integer>("world", 1));
@@ -105,9 +89,22 @@ public class JavaStatefulNetworkWordCount {
           }
         });
 
+    // Update the cumulative count function
+    final Function4<Time, String, Optional<Integer>, State<Integer>, 
Optional<Tuple2<String, Integer>>> trackStateFunc =
+        new Function4<Time, String, Optional<Integer>, State<Integer>, 
Optional<Tuple2<String, Integer>>>() {
+
+          @Override
+          public Optional<Tuple2<String, Integer>> call(Time time, String 
word, Optional<Integer> one, State<Integer> state) {
+            int sum = one.or(0) + (state.exists() ? state.get() : 0);
+            Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, 
sum);
+            state.update(sum);
+            return Optional.of(output);
+          }
+        };
+
     // This will give a Dstream made of state (which is the cumulative count 
of the words)
-    JavaPairDStream<String, Integer> stateDstream = 
wordsDstream.updateStateByKey(updateFunction,
-            new HashPartitioner(ssc.sparkContext().defaultParallelism()), 
initialRDD);
+    JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> 
stateDstream =
+        
wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
 
     stateDstream.print();
     ssc.start();

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
----------------------------------------------------------------------
diff --git 
a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
 
b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index be2ae0b..a4f847f 100644
--- 
a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ 
b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")
 
-    // Initial RDD input to updateStateByKey
+    // Initial RDD input to trackStateByKey
     val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 
1)))
 
     // Create a ReceiverInputDStream on target ip:port and count the

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
----------------------------------------------------------------------
diff --git 
a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
 
b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
index 73091cf..163ae92 100644
--- 
a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
+++ 
b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
@@ -31,9 +31,12 @@ import org.junit.Test;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function4;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
 
 /**
  * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using 
java 8
@@ -831,4 +834,44 @@ public class Java8APISuite extends 
LocalJavaStreamingContext implements Serializ
     Assert.assertEquals(expected, result);
   }
 
+  /**
+   * Compile-only check of the Java 8 lambda form of the trackStateByKey API. It exercises both
+   * Java overloads of StateSpec.function (the one receiving the batch time and the key, and the
+   * one receiving only the value and the state), every State method and every StateSpec option.
+   * The streams are null and the method carries no @Test annotation, so it is never run.
+   */
+  public void testTrackStateByAPI() {
+    JavaPairRDD<String, Boolean> initialRDD = null;
+    JavaPairDStream<String, Integer> wordsDstream = null;
+
+    // Variant 1: (time, key, value, state) -> Optional<emitted>
+    Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>> timedFunc =
+        (time, key, value, state) -> {
+          // Use all State's methods here
+          state.exists();
+          state.get();
+          state.isTimingOut();
+          state.remove();
+          state.update(true);
+          return Optional.of(2.0);
+        };
+    StateSpec<String, Integer, Boolean, Double> timedSpec =
+        StateSpec.<String, Integer, Boolean, Double>function(timedFunc)
+            .initialState(initialRDD)
+            .numPartitions(10)
+            .partitioner(new HashPartitioner(10))
+            .timeout(Durations.seconds(10));
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.trackStateByKey(timedSpec);
+    JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
+
+    // Variant 2: (value, state) -> emitted
+    Function2<Optional<Integer>, State<Boolean>, Double> simpleFunc = (value, state) -> {
+      state.exists();
+      state.get();
+      state.isTimingOut();
+      state.remove();
+      state.update(true);
+      return 2.0;
+    };
+    StateSpec<String, Integer, Boolean, Double> simpleSpec =
+        StateSpec.<String, Integer, Boolean, Double>function(simpleFunc)
+            .initialState(initialRDD)
+            .numPartitions(10)
+            .partitioner(new HashPartitioner(10))
+            .timeout(Durations.seconds(10));
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.trackStateByKey(simpleSpec);
+    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/State.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 7dd1b72..604e64f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -50,9 +50,30 @@ import org.apache.spark.annotation.Experimental
  *
  * }}}
  *
- * Java example:
+ * Java example of using `State`:
  * {{{
- *      TODO(@zsxwing)
+ *   // A tracking function that maintains an integer state and returns a String
+ *   Function2<Optional<Integer>, State<Integer>, Optional<String>> 
trackStateFunc =
+ *       new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *
+ *         @Override
+ *         public Optional<String> call(Optional<Integer> one, State<Integer> 
state) {
+ *           if (state.exists()) {
+ *             int existingState = state.get(); // Get the existing state
+ *             boolean shouldRemove = ...; // Decide whether to remove the 
state
+ *             if (shouldRemove) {
+ *               state.remove(); // Remove the state
+ *             } else {
+ *               int newState = ...;
+ *               state.update(newState); // Set the new state
+ *             }
+ *           } else {
+ *             int initialState = ...; // Set the initial state
+ *             state.update(initialState);
+ *           }
+ *           // return something
+ *         }
+ *       };
  * }}}
  */
 @Experimental

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index c9fe35e..bea5b9d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.streaming
 
-import scala.reflect.ClassTag
-
+import com.google.common.base.Optional
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 
=> JFunction4}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.{HashPartitioner, Partitioner}
 
-
 /**
  * :: Experimental ::
  * Abstract class representing all the specifications of the DStream 
transformation
@@ -49,12 +48,12 @@ import org.apache.spark.{HashPartitioner, Partitioner}
  *
  * Example in Java:
  * {{{
- *    StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- *      StateStateSpec.function[KeyType, ValueType, StateType, 
EmittedDataType](trackingFunction)
+ *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ *      StateSpec.<KeyType, ValueType, StateType, 
EmittedDataType>function(trackingFunction)
  *                    .numPartitions(10);
  *
- *    JavaDStream[EmittedDataType] emittedRecordDStream =
- *      javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedDataType> emittedRecordDStream =
+ *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
  * }}}
  */
 @Experimental
@@ -92,6 +91,7 @@ sealed abstract class StateSpec[KeyType, ValueType, 
StateType, EmittedType] exte
 /**
  * :: Experimental ::
  * Builder object for creating instances of 
[[org.apache.spark.streaming.StateSpec StateSpec]]
  * that is used for specifying the parameters of the DStream transformation
  * `trackStateByKey` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
@@ -103,28 +103,27 @@ sealed abstract class StateSpec[KeyType, ValueType, 
StateType, EmittedType] exte
  *      ...
  *    }
  *
- *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
- *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, 
EmittedDataType](spec)
+ *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, 
EmittedDataType](
+ *        StateSpec.function(trackingFunction).numPartitions(10))
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- *      StateStateSpec.function[KeyType, ValueType, StateType, 
EmittedDataType](trackingFunction)
+ *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ *      StateSpec.<KeyType, ValueType, StateType, 
EmittedDataType>function(trackingFunction)
  *                    .numPartitions(10);
  *
- *    JavaDStream[EmittedDataType] emittedRecordDStream =
- *      javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedDataType> emittedRecordDStream =
+ *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
  * }}}
  */
 @Experimental
 object StateSpec {
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting 
all the specifications
-   * `trackStateByKey` operation on a
-   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
-   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+   *
    * @param trackingFunction The function applied on every data item to manage 
the associated state
    *                         and generate the emitted data
    * @tparam KeyType      Class of the keys
@@ -141,9 +140,9 @@ object StateSpec {
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting 
all the specifications
-   * `trackStateByKey` operation on a
-   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] 
(Scala) or a
-   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] 
(Java).
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+   *
    * @param trackingFunction The function applied on every data item to manage 
the associated state
    *                         and generate the emitted data
    * @tparam ValueType    Class of the values
@@ -160,6 +159,48 @@ object StateSpec {
       }
     new StateSpecImpl(wrappedFunction)
   }
+
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the
+   * specifications of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]], using a Java
+   * tracking function that also receives the batch time and the key.
+   *
+   * The Scala `Option` value is handed to the Java function as a Guava `Optional`, and the
+   * function's `Optional` result is turned back into an `Option` (absent becomes `None`).
+   *
+   * @param javaTrackingFunction The function applied on every data item to manage the associated
+   *                             state and generate the emitted data
+   * @tparam KeyType      Class of the keys
+   * @tparam ValueType    Class of the values
+   * @tparam StateType    Class of the states data
+   * @tparam EmittedType  Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      javaTrackingFunction:
+        JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]
+    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    val wrapped: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] =
+      (batchTime, key, value, state) => {
+        val result =
+          javaTrackingFunction.call(batchTime, key, JavaUtils.optionToOptional(value), state)
+        if (result.isPresent) Some(result.get) else None
+      }
+    StateSpec.function(wrapped)
+  }
+
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+   *
+   * The value is passed to the Java function as a Guava `Optional`. It is absent whenever the
+   * tracking function is invoked without a value (`None`), so the Java function must not assume
+   * the value is present.
+   *
+   * @param javaTrackingFunction The function applied on every data item to manage the associated
+   *                             state and generate the emitted data
+   * @tparam KeyType      Class of the keys
+   * @tparam ValueType    Class of the values
+   * @tparam StateType    Class of the states data
+   * @tparam EmittedType  Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
+    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
+      // `v` may be None: calling `v.get` would throw NoSuchElementException. A Some(null) value
+      // still maps to an absent Optional, as fromNullable did before.
+      val javaValue = v match {
+        case Some(value) => Optional.fromNullable(value)
+        case None => Optional.absent[ValueType]()
+      }
+      javaTrackingFunction.call(javaValue, s)
+    }
+    StateSpec.function(trackingFunc)
+  }
 }
 
 
@@ -184,7 +225,6 @@ case class StateSpecImpl[K, V, S, T](
     this
   }
 
-
   override def numPartitions(numPartitions: Int): this.type = {
     this.partitioner(new HashPartitioner(numPartitions))
     this

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index e2aec6c..70e32b3 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -28,8 +28,10 @@ import com.google.common.base.Optional
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.{JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+
 import org.apache.spark.Partitioner
-import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => 
JFunction2}
@@ -426,6 +428,48 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
     )
   }
 
+  /**
+   * :: Experimental ::
+   * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
+   * with a continuously updated per-key state. The user-provided state tracking function is
+   * applied on each keyed data item along with its corresponding state. The function can choose to
+   * update/remove the state and return transformed data, which forms the
+   * [[JavaTrackStateDStream]].
+   *
+   * The specifications of this transformation are made through the
+   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+   * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
+   *
+   * Example of using `trackStateByKey`:
+   * {{{
+   *   // A tracking function that maintains an integer state and returns a String
+   *   Function2<Optional<Integer>, State<Integer>, String> trackStateFunc =
+   *       new Function2<Optional<Integer>, State<Integer>, String>() {
+   *
+   *         @Override
+   *         public String call(Optional<Integer> one, State<Integer> state) {
+   *           // Check if state exists, accordingly update/remove state and return transformed data
+   *         }
+   *       };
+   *
+   *    JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
+   *        keyValueDStream.<Integer, String>trackStateByKey(
+   *                 StateSpec.<Integer, Integer, Integer, String>function(trackStateFunc)
+   *                     .numPartitions(10));
+   * }}}
+   *
+   * @param spec          Specification of this transformation
+   * @tparam StateType    Class type of the state
+   * @tparam EmittedType  Class type of the transformed data returned by the tracking function
+   */
+  @Experimental
+  def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
+    JavaTrackStateDStream[K, V, StateType, EmittedType] = {
+    new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+      JavaSparkContext.fakeClassTag,
+      JavaSparkContext.fakeClassTag))
+  }
+
   private def convertUpdateStateFunction[S](in: JFunction2[JList[V], 
Optional[S], Optional[S]]):
   (Seq[V], Option[S]) => Option[S] = {
     val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => {

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
new file mode 100644
index 0000000..f459930
--- /dev/null
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.java
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.streaming.dstream.TrackStateDStream
+
+/**
+ * :: Experimental ::
+ * [[JavaDStream]] of the records emitted by the tracking function of a `trackStateByKey`
+ * operation on a [[JavaPairDStream]]. It also gives access to the stream of state snapshots,
+ * that is, the state data of all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the values in the input pair DStream
+ * @tparam StateType Class of the state
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
+    dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
+  extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+
+  /**
+   * Return a [[JavaPairDStream]] of (key, state) pairs that wraps the underlying
+   * `TrackStateDStream.stateSnapshots()` stream.
+   */
+  def stateSnapshots(): JavaPairDStream[KeyType, StateType] = {
+    val snapshots = dstream.stateSnapshots()
+    new JavaPairDStream(snapshots)(JavaSparkContext.fakeClassTag, JavaSparkContext.fakeClassTag)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
index 58d89c9..98e881e 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -35,6 +35,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, 
TrackStateRDDRecord}
  * all keys after a batch has updated them.
  *
  * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the values in the input pair DStream
  * @tparam StateType Class of the state data
  * @tparam EmittedType Class of the emitted records
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index ed7cea2..fc51496 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -70,12 +70,14 @@ private[streaming] class TrackStateRDDPartition(
  *                           in the `prevStateRDD` to create `this` RDD
  * @param trackingFunction The function that will be used to update state and 
return new data
  * @param batchTime        The time of the batch to which this RDD belongs to. 
Use to update
+ * @param timeoutThresholdTime The time used to determine which keys have timed out
  */
 private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, 
T: ClassTag](
     private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
     private var partitionedDataRDD: RDD[(K, V)],
     trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
-    batchTime: Time, timeoutThresholdTime: Option[Long]
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long]
   ) extends RDD[TrackStateRDDRecord[K, S, T]](
     partitionedDataRDD.sparkContext,
     List(

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index ed622ef..34287c3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -267,7 +267,11 @@ private[streaming] class OpenHashMapBasedStateMap[K: 
ClassTag, S: ClassTag](
 
     // Read the data of the delta
     val deltaMapSize = inputStream.readInt()
-    deltaMap = new OpenHashMap[K, StateInfo[S]]()
+    deltaMap = if (deltaMapSize != 0) {
+        new OpenHashMap[K, StateInfo[S]](deltaMapSize)
+      } else {
+        new OpenHashMap[K, StateInfo[S]](initialCapacity)
+      }
     var deltaMapCount = 0
     while (deltaMapCount < deltaMapSize) {
       val key = inputStream.readObject().asInstanceOf[K]

http://git-wip-us.apache.org/repos/asf/spark/blob/05666e09/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
 
b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
new file mode 100644
index 0000000..eac4cdd
--- /dev/null
+++ 
b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.util.ManualClock;
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function4;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+
+/**
+ * Java API tests for {@code trackStateByKey} on {@link JavaPairDStream} with {@link StateSpec}.
+ */
+public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
+
+  /**
+   * This test is only for testing the APIs. It's not necessary to run it.
+   *
+   * <p>It is deliberately not annotated with {@code @Test}: {@code initialRDD} and
+   * {@code wordsDstream} are {@code null}, so the method is compile-only and must not be
+   * executed. Compiling it checks that the Java API accepts both the {@link Function4} form and
+   * the {@link Function2} form of the tracking function.
+   */
+  public void testAPI() {
+    JavaPairRDD<String, Boolean> initialRDD = null;
+    JavaPairDStream<String, Integer> wordsDstream = null;
+
+    final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
+        trackStateFunc =
+        new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
+
+          @Override
+          public Optional<Double> call(
+              Time time, String word, Optional<Integer> one, State<Boolean> state) {
+            // Call every State method so that each Java signature is compile-checked.
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return Optional.of(2.0);
+          }
+        };
+
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.trackStateByKey(
+            StateSpec.function(trackStateFunc)
+                .initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
+
+    final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
+        new Function2<Optional<Integer>, State<Boolean>, Double>() {
+
+          @Override
+          public Double call(Optional<Integer> one, State<Boolean> state) {
+            // Call every State method so that each Java signature is compile-checked.
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return 2.0;
+          }
+        };
+
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.trackStateByKey(
+            StateSpec.<String, Integer, Boolean, Double> function(trackStateFunc2)
+                .initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+  }
+
+  /**
+   * Runs a per-word running count through {@code trackStateByKey}.
+   *
+   * <p>Every input record emits the updated count for its word. The state snapshot taken after
+   * each batch holds the latest count of every word seen so far, including words that did not
+   * appear in that batch.
+   */
+  @Test
+  public void testBasicFunction() {
+    List<List<String>> inputData = Arrays.asList(
+        Collections.<String>emptyList(),
+        Arrays.asList("a"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a", "b", "c"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a"),
+        Collections.<String>emptyList()
+    );
+
+    List<Set<Integer>> outputData = Arrays.asList(
+        Collections.<Integer>emptySet(),
+        Sets.newHashSet(1),
+        Sets.newHashSet(2, 1),
+        Sets.newHashSet(3, 2, 1),
+        Sets.newHashSet(4, 3),
+        Sets.newHashSet(5),
+        Collections.<Integer>emptySet()
+    );
+
+    List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
+        Collections.<Tuple2<String, Integer>>emptySet(),
+        Sets.newHashSet(new Tuple2<String, Integer>("a", 1)),
+        Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 3),
+            new Tuple2<String, Integer>("b", 2),
+            new Tuple2<String, Integer>("c", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 4),
+            new Tuple2<String, Integer>("b", 3),
+            new Tuple2<String, Integer>("c", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 5),
+            new Tuple2<String, Integer>("b", 3),
+            new Tuple2<String, Integer>("c", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 5),
+            new Tuple2<String, Integer>("b", 3),
+            new Tuple2<String, Integer>("c", 1))
+    );
+
+    Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc =
+        new Function2<Optional<Integer>, State<Integer>, Integer>() {
+
+          @Override
+          public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
+            // New running count = this record's value + the stored count (0 if none yet).
+            int sum = value.or(0) + (state.exists() ? state.get() : 0);
+            state.update(sum);
+            return sum;
+          }
+        };
+    testOperation(
+        inputData,
+        StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc),
+        outputData,
+        stateData);
+  }
+
+  /**
+   * Feeds {@code input} through {@code trackStateByKey} and compares the results batch by batch.
+   *
+   * <p>Each input element is paired with the value {@code 1} before state tracking. The emitted
+   * records and the state snapshots of every batch are collected into sets, so the order of
+   * records within a batch is not checked.
+   *
+   * @param input input keys, one inner list per batch
+   * @param trackStateSpec the state spec under test
+   * @param expectedOutputs expected emitted records for each batch
+   * @param expectedStateSnapshots expected (key, state) pairs after each batch
+   */
+  private <K, S, T> void testOperation(
+      List<List<K>> input,
+      StateSpec<K, Integer, S, T> trackStateSpec,
+      List<Set<T>> expectedOutputs,
+      List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
+    int numBatches = expectedOutputs.size();
+    JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
+    // Pair every input element with 1 so that the tracking function sees (key, 1) records.
+    JavaTrackStateDStream<K, Integer, S, T> trackStateStream =
+        JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
+          @Override
+          public Tuple2<K, Integer> call(K x) throws Exception {
+            return new Tuple2<K, Integer>(x, 1);
+          }
+        })).trackStateByKey(trackStateSpec);
+
+    // These lists are filled from foreachRDD callbacks and read by this thread afterwards,
+    // hence the synchronized wrappers.
+    final List<Set<T>> collectedOutputs =
+        Collections.synchronizedList(Lists.<Set<T>>newArrayList());
+    trackStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
+      @Override
+      public Void call(JavaRDD<T> rdd) throws Exception {
+        collectedOutputs.add(Sets.newHashSet(rdd.collect()));
+        return null;
+      }
+    });
+    final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
+        Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
+    trackStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
+      @Override
+      public Void call(JavaPairRDD<K, S> rdd) throws Exception {
+        collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
+        return null;
+      }
+    });
+    BatchCounter batchCounter = new BatchCounter(ssc.ssc());
+    ssc.start();
+    // Advance the manual clock just past the end of the last batch so that all numBatches
+    // batches are generated, then wait (10000 ms timeout) for them to complete.
+    ((ManualClock) ssc.ssc().scheduler().clock())
+        .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
+
+    Assert.assertEquals(expectedOutputs, collectedOutputs);
+    Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to