Repository: spark
Updated Branches:
  refs/heads/master 7a3f589ef -> 69c67abaa


http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/python/pyspark/streaming/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
new file mode 100644
index 0000000..a8d876d
--- /dev/null
+++ b/python/pyspark/streaming/tests.py
@@ -0,0 +1,545 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from itertools import chain
+import time
+import operator
+import unittest
+import tempfile
+
+from pyspark.context import SparkConf, SparkContext, RDD
+from pyspark.streaming.context import StreamingContext
+
+
+class PySparkStreamingTestCase(unittest.TestCase):
+
+    timeout = 10  # seconds
+    duration = 1
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.default.parallelism", 1)
+        self.sc = SparkContext(appName=class_name, conf=conf)
+        self.sc.setCheckpointDir("/tmp")
+        # TODO: decrease duration to speed up tests
+        self.ssc = StreamingContext(self.sc, self.duration)
+
+    def tearDown(self):
+        self.ssc.stop()
+
+    def wait_for(self, result, n):
+        start_time = time.time()
+        while len(result) < n and time.time() - start_time < self.timeout:
+            time.sleep(0.01)
+        if len(result) < n:
+            print "timeout after", self.timeout
+
+    def _take(self, dstream, n):
+        """
+        Return the first `n` elements in the stream (will start and stop).
+        """
+        results = []
+
+        def take(_, rdd):
+            if rdd and len(results) < n:
+                results.extend(rdd.take(n - len(results)))
+
+        dstream.foreachRDD(take)
+
+        self.ssc.start()
+        self.wait_for(results, n)
+        return results
+
+    def _collect(self, dstream, n, block=True):
+        """
+        Collect the elements of each RDD into the returned list.
+
+        :return: list, which will have the collected items.
+        """
+        result = []
+
+        def get_output(_, rdd):
+            if rdd and len(result) < n:
+                r = rdd.collect()
+                if r:
+                    result.append(r)
+
+        dstream.foreachRDD(get_output)
+
+        if not block:
+            return result
+
+        self.ssc.start()
+        self.wait_for(result, n)
+        return result
+
+    def _test_func(self, input, func, expected, sort=False, input2=None):
+        """
+        @param input: dataset for the test. This should be a list of lists.
+        @param func: wrapped function. This function should return a PythonDStream object.
+        @param expected: expected output for this testcase.
+        """
+        if not isinstance(input[0], RDD):
+            input = [self.sc.parallelize(d, 1) for d in input]
+        input_stream = self.ssc.queueStream(input)
+        if input2 and not isinstance(input2[0], RDD):
+            input2 = [self.sc.parallelize(d, 1) for d in input2]
+        input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
+        # Apply test function to stream.
+        if input2:
+            stream = func(input_stream, input_stream2)
+        else:
+            stream = func(input_stream)
+
+        result = self._collect(stream, len(expected))
+        if sort:
+            self._sort_result_based_on_key(result)
+            self._sort_result_based_on_key(expected)
+        self.assertEqual(expected, result)
+
+    def _sort_result_based_on_key(self, outputs):
+        """Sort the list based on first value."""
+        for output in outputs:
+            output.sort(key=lambda x: x[0])
+
+
+class BasicOperationTests(PySparkStreamingTestCase):
+
+    def test_map(self):
+        """Basic operation test for DStream.map."""
+        input = [range(1, 5), range(5, 9), range(9, 13)]
+
+        def func(dstream):
+            return dstream.map(str)
+        expected = map(lambda x: map(str, x), input)
+        self._test_func(input, func, expected)
+
+    def test_flatMap(self):
+        """Basic operation test for DStream.faltMap."""
+        input = [range(1, 5), range(5, 9), range(9, 13)]
+
+        def func(dstream):
+            return dstream.flatMap(lambda x: (x, x * 2))
+        expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
+                       input)
+        self._test_func(input, func, expected)
+
+    def test_filter(self):
+        """Basic operation test for DStream.filter."""
+        input = [range(1, 5), range(5, 9), range(9, 13)]
+
+        def func(dstream):
+            return dstream.filter(lambda x: x % 2 == 0)
+        expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
+        self._test_func(input, func, expected)
+
+    def test_count(self):
+        """Basic operation test for DStream.count."""
+        input = [range(5), range(10), range(20)]
+
+        def func(dstream):
+            return dstream.count()
+        expected = map(lambda x: [len(x)], input)
+        self._test_func(input, func, expected)
+
+    def test_reduce(self):
+        """Basic operation test for DStream.reduce."""
+        input = [range(1, 5), range(5, 9), range(9, 13)]
+
+        def func(dstream):
+            return dstream.reduce(operator.add)
+        expected = map(lambda x: [reduce(operator.add, x)], input)
+        self._test_func(input, func, expected)
+
+    def test_reduceByKey(self):
+        """Basic operation test for DStream.reduceByKey."""
+        input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
+                 [("", 1), ("", 1), ("", 1), ("", 1)],
+                 [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]
+
+        def func(dstream):
+            return dstream.reduceByKey(operator.add)
+        expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
+        self._test_func(input, func, expected, sort=True)
+
+    def test_mapValues(self):
+        """Basic operation test for DStream.mapValues."""
+        input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+                 [("", 4), (1, 1), (2, 2), (3, 3)],
+                 [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+        def func(dstream):
+            return dstream.mapValues(lambda x: x + 10)
+        expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
+                    [("", 14), (1, 11), (2, 12), (3, 13)],
+                    [(1, 11), (2, 11), (3, 11), (4, 11)]]
+        self._test_func(input, func, expected, sort=True)
+
+    def test_flatMapValues(self):
+        """Basic operation test for DStream.flatMapValues."""
+        input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+                 [("", 4), (1, 1), (2, 1), (3, 1)],
+                 [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+        def func(dstream):
+            return dstream.flatMapValues(lambda x: (x, x + 10))
+        expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
+                     ("c", 1), ("c", 11), ("d", 1), ("d", 11)],
+                    [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 
1), (3, 11)],
+                    [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 
1), (4, 11)]]
+        self._test_func(input, func, expected)
+
+    def test_glom(self):
+        """Basic operation test for DStream.glom."""
+        input = [range(1, 5), range(5, 9), range(9, 13)]
+        rdds = [self.sc.parallelize(r, 2) for r in input]
+
+        def func(dstream):
+            return dstream.glom()
+        expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
+        self._test_func(rdds, func, expected)
+
+    def test_mapPartitions(self):
+        """Basic operation test for DStream.mapPartitions."""
+        input = [range(1, 5), range(5, 9), range(9, 13)]
+        rdds = [self.sc.parallelize(r, 2) for r in input]
+
+        def func(dstream):
+            def f(iterator):
+                yield sum(iterator)
+            return dstream.mapPartitions(f)
+        expected = [[3, 7], [11, 15], [19, 23]]
+        self._test_func(rdds, func, expected)
+
+    def test_countByValue(self):
+        """Basic operation test for DStream.countByValue."""
+        input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
+
+        def func(dstream):
+            return dstream.countByValue()
+        expected = [[4], [4], [3]]
+        self._test_func(input, func, expected)
+
+    def test_groupByKey(self):
+        """Basic operation test for DStream.groupByKey."""
+        input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+                 [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+        def func(dstream):
+            return dstream.groupByKey().mapValues(list)
+
+        expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
+                    [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
+                    [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
+        self._test_func(input, func, expected, sort=True)
+
+    def test_combineByKey(self):
+        """Basic operation test for DStream.combineByKey."""
+        input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+                 [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+        def func(dstream):
+            def add(a, b):
+                return a + str(b)
+            return dstream.combineByKey(str, add, add)
+        expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
+                    [(1, "111"), (2, "11"), (3, "1")],
+                    [("a", "11"), ("b", "1"), ("", "111")]]
+        self._test_func(input, func, expected, sort=True)
+
+    def test_repartition(self):
+        input = [range(1, 5), range(5, 9)]
+        rdds = [self.sc.parallelize(r, 2) for r in input]
+
+        def func(dstream):
+            return dstream.repartition(1).glom()
+        expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
+        self._test_func(rdds, func, expected)
+
+    def test_union(self):
+        input1 = [range(3), range(5), range(6)]
+        input2 = [range(3, 6), range(5, 6)]
+
+        def func(d1, d2):
+            return d1.union(d2)
+
+        expected = [range(6), range(6), range(6)]
+        self._test_func(input1, func, expected, input2=input2)
+
+    def test_cogroup(self):
+        input = [[(1, 1), (2, 1), (3, 1)],
+                 [(1, 1), (1, 1), (1, 1), (2, 1)],
+                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
+        input2 = [[(1, 2)],
+                  [(4, 1)],
+                  [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
+
+        def func(d1, d2):
+            return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
+
+        expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
+                    [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
+                    [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], 
[1, 2]))]]
+        self._test_func(input, func, expected, sort=True, input2=input2)
+
+    def test_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.join(b)
+
+        expected = [[('b', (2, 3))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_left_outer_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.leftOuterJoin(b)
+
+        expected = [[('a', (1, None)), ('b', (2, 3))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_right_outer_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.rightOuterJoin(b)
+
+        expected = [[('b', (2, 3)), ('c', (None, 4))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_full_outer_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.fullOuterJoin(b)
+
+        expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_update_state_by_key(self):
+
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s
+
+        input = [[('k', i)] for i in range(5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater)
+
+        expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
+
+class WindowFunctionTests(PySparkStreamingTestCase):
+
+    timeout = 20
+
+    def test_window(self):
+        input = [range(1), range(2), range(3), range(4), range(5)]
+
+        def func(dstream):
+            return dstream.window(3, 1).count()
+
+        expected = [[1], [3], [6], [9], [12], [9], [5]]
+        self._test_func(input, func, expected)
+
+    def test_count_by_window(self):
+        input = [range(1), range(2), range(3), range(4), range(5)]
+
+        def func(dstream):
+            return dstream.countByWindow(3, 1)
+
+        expected = [[1], [3], [6], [9], [12], [9], [5]]
+        self._test_func(input, func, expected)
+
+    def test_count_by_window_large(self):
+        input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+        def func(dstream):
+            return dstream.countByWindow(5, 1)
+
+        expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
+        self._test_func(input, func, expected)
+
+    def test_count_by_value_and_window(self):
+        input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+        def func(dstream):
+            return dstream.countByValueAndWindow(5, 1)
+
+        expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
+        self._test_func(input, func, expected)
+
+    def test_group_by_key_and_window(self):
+        input = [[('a', i)] for i in range(5)]
+
+        def func(dstream):
+            return dstream.groupByKeyAndWindow(3, 1).mapValues(list)
+
+        expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
+                    [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
+        self._test_func(input, func, expected)
+
+    def test_reduce_by_invalid_window(self):
+        input1 = [range(3), range(5), range(1), range(6)]
+        d1 = self.ssc.queueStream(input1)
+        self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
+        self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
+
+
+class StreamingContextTests(PySparkStreamingTestCase):
+
+    duration = 0.1
+
+    def _add_input_stream(self):
+        inputs = map(lambda x: range(1, x), range(101))
+        stream = self.ssc.queueStream(inputs)
+        self._collect(stream, 1, block=False)
+
+    def test_stop_only_streaming_context(self):
+        self._add_input_stream()
+        self.ssc.start()
+        self.ssc.stop(False)
+        self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+    def test_stop_multiple_times(self):
+        self._add_input_stream()
+        self.ssc.start()
+        self.ssc.stop()
+        self.ssc.stop()
+
+    def test_queue_stream(self):
+        input = [range(i + 1) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        result = self._collect(dstream, 3)
+        self.assertEqual(input, result)
+
+    def test_text_file_stream(self):
+        d = tempfile.mkdtemp()
+        self.ssc = StreamingContext(self.sc, self.duration)
+        dstream2 = self.ssc.textFileStream(d).map(int)
+        result = self._collect(dstream2, 2, block=False)
+        self.ssc.start()
+        for name in ('a', 'b'):
+            time.sleep(1)
+            with open(os.path.join(d, name), "w") as f:
+                f.writelines(["%d\n" % i for i in range(10)])
+        self.wait_for(result, 2)
+        self.assertEqual([range(10), range(10)], result)
+
+    def test_union(self):
+        input = [range(i + 1) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        dstream2 = self.ssc.queueStream(input)
+        dstream3 = self.ssc.union(dstream, dstream2)
+        result = self._collect(dstream3, 3)
+        expected = [i * 2 for i in input]
+        self.assertEqual(expected, result)
+
+    def test_transform(self):
+        dstream1 = self.ssc.queueStream([[1]])
+        dstream2 = self.ssc.queueStream([[2]])
+        dstream3 = self.ssc.queueStream([[3]])
+
+        def func(rdds):
+            rdd1, rdd2, rdd3 = rdds
+            return rdd2.union(rdd3).union(rdd1)
+
+        dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
+
+        self.assertEqual([2, 3, 1], self._take(dstream, 3))
+
+
+class CheckpointTests(PySparkStreamingTestCase):
+
+    def setUp(self):
+        pass
+
+    def test_get_or_create(self):
+        inputd = tempfile.mkdtemp()
+        outputd = tempfile.mkdtemp() + "/"
+
+        def updater(vs, s):
+            return sum(vs, s or 0)
+
+        def setup():
+            conf = SparkConf().set("spark.default.parallelism", 1)
+            sc = SparkContext(conf=conf)
+            ssc = StreamingContext(sc, 0.5)
+            dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
+            wc = dstream.updateStateByKey(updater)
+            wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
+            wc.checkpoint(.5)
+            return ssc
+
+        cpd = tempfile.mkdtemp("test_streaming_cps")
+        self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+        ssc.start()
+
+        def check_output(n):
+            while not os.listdir(outputd):
+                time.sleep(0.1)
+            time.sleep(1)  # make sure mtime is larger than the previous one
+            with open(os.path.join(inputd, str(n)), 'w') as f:
+                f.writelines(["%d\n" % i for i in range(10)])
+
+            while True:
+                p = os.path.join(outputd, max(os.listdir(outputd)))
+                if '_SUCCESS' not in os.listdir(p):
+                    # not finished
+                    time.sleep(0.01)
+                    continue
+                ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
+                d = ordd.values().map(int).collect()
+                if not d:
+                    time.sleep(0.01)
+                    continue
+                self.assertEqual(10, len(d))
+                s = set(d)
+                self.assertEqual(1, len(s))
+                m = s.pop()
+                if n > m:
+                    continue
+                self.assertEqual(n, m)
+                break
+
+        check_output(1)
+        check_output(2)
+        ssc.stop(True, True)
+
+        time.sleep(1)
+        self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+        ssc.start()
+        check_output(3)
+
+
+if __name__ == "__main__":
+    unittest.main()
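
A minimal, self-contained sketch of the queueStream/foreachRDD pattern these tests rely on may help when reading the harness above. It is illustrative only: the app name, batch contents, and the sleep-based wait are arbitrary stand-ins, and it assumes a local Spark build that includes the streaming Python API added in this patch.

import time

from pyspark import SparkConf, SparkContext
from pyspark.streaming.context import StreamingContext

conf = SparkConf().set("spark.default.parallelism", 1)
sc = SparkContext(appName="queue_stream_sketch", conf=conf)
ssc = StreamingContext(sc, 1)  # 1-second batch duration

# Each list becomes one RDD, i.e. one batch of the input DStream.
batches = [[1, 2, 3], [4, 5, 6]]
stream = ssc.queueStream([sc.parallelize(b, 1) for b in batches])

collected = []

def collect(batch_time, rdd):
    # Called once per batch; mirrors _collect() in the test harness.
    if rdd is not None:
        data = rdd.collect()
        if data:
            collected.append(data)

stream.map(lambda x: x * 10).foreachRDD(collect)

ssc.start()
time.sleep(5)  # crude stand-in for wait_for() above
ssc.stop()
print(collected)  # expected: [[10, 20, 30], [40, 50, 60]]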

http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/python/pyspark/streaming/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
new file mode 100644
index 0000000..86ee5aa
--- /dev/null
+++ b/python/pyspark/streaming/util.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from datetime import datetime
+import traceback
+
+from pyspark import SparkContext, RDD
+
+
+class TransformFunction(object):
+    """
+    This class wraps a function RDD[X] -> RDD[Y] that was passed to
+    DStream.transform(), allowing it to be called from Java via Py4J's
+    callback server.
+
+    Java calls this function with a sequence of JavaRDDs and this function
+    returns a single JavaRDD pointer back to Java.
+    """
+    _emptyRDD = None
+
+    def __init__(self, ctx, func, *deserializers):
+        self.ctx = ctx
+        self.func = func
+        self.deserializers = deserializers
+
+    def call(self, milliseconds, jrdds):
+        try:
+            if self.ctx is None:
+                self.ctx = SparkContext._active_spark_context
+            if not self.ctx or not self.ctx._jsc:
+                # stopped
+                return
+
+            # extend deserializers with the first one
+            sers = self.deserializers
+            if len(sers) < len(jrdds):
+                sers += (sers[0],) * (len(jrdds) - len(sers))
+
+            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+                    for jrdd, ser in zip(jrdds, sers)]
+            t = datetime.fromtimestamp(milliseconds / 1000.0)
+            r = self.func(t, *rdds)
+            if r:
+                return r._jrdd
+        except Exception:
+            traceback.print_exc()
+
+    def __repr__(self):
+        return "TransformFunction(%s)" % self.func
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
+
+
+class TransformFunctionSerializer(object):
+    """
+    This class implements a serializer for PythonTransformFunction Java
+    objects.
+
+    This is necessary because the Java PythonTransformFunction objects are
+    actually Py4J references to Python objects and thus are not directly
+    serializable. When Java needs to serialize a PythonTransformFunction,
+    it uses this class to invoke Python, which returns the serialized function
+    as a byte array.
+    """
+    def __init__(self, ctx, serializer, gateway=None):
+        self.ctx = ctx
+        self.serializer = serializer
+        self.gateway = gateway or self.ctx._gateway
+        self.gateway.jvm.PythonDStream.registerSerializer(self)
+
+    def dumps(self, id):
+        try:
+            func = self.gateway.gateway_property.pool[id]
+            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+        except Exception:
+            traceback.print_exc()
+
+    def loads(self, bytes):
+        try:
+            f, deserializers = self.serializer.loads(str(bytes))
+            return TransformFunction(self.ctx, f, *deserializers)
+        except Exception:
+            traceback.print_exc()
+
+    def __repr__(self):
+        return "TransformFunctionSerializer(%s)" % self.serializer
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
+
+
+def rddToFileName(prefix, suffix, timestamp):
+    """
+    Return string prefix-time(.suffix)
+
+    >>> rddToFileName("spark", None, 12345678910)
+    'spark-12345678910'
+    >>> rddToFileName("spark", "tmp", 12345678910)
+    'spark-12345678910.tmp'
+    """
+    if isinstance(timestamp, datetime):
+        seconds = time.mktime(timestamp.timetuple())
+        timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
+    if suffix is None:
+        return prefix + "-" + str(timestamp)
+    else:
+        return prefix + "-" + str(timestamp) + "." + suffix
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
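
The two helpers above sit on either side of the Java/Python boundary: TransformFunction.call() receives the batch time from Java as epoch milliseconds and hands the wrapped function a datetime, while rddToFileName() goes the other way when output paths are built for saveAsTextFiles(). A small plain-Python illustration of that round trip (no Spark needed; the millisecond value is arbitrary):

import time
from datetime import datetime

millis = 12345678910  # arbitrary batch time in epoch milliseconds

# What TransformFunction.call() passes to the wrapped function:
batch_time = datetime.fromtimestamp(millis / 1000.0)

# What rddToFileName() produces for a given prefix/suffix:
print("spark-%d" % millis)        # spark-12345678910
print("spark-%d.tmp" % millis)    # spark-12345678910.tmp

# How rddToFileName() converts a datetime back to milliseconds:
seconds = time.mktime(batch_time.timetuple())
print(int(seconds * 1000) + batch_time.microsecond // 1000)  # 12345678910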

http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index f6a9684..2f98443 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -81,6 +81,11 @@ function run_mllib_tests() {
     run_test "pyspark/mllib/tests.py"
 }
 
+function run_streaming_tests() {
+    run_test "pyspark/streaming/util.py"
+    run_test "pyspark/streaming/tests.py"
+}
+
 echo "Running PySpark tests. Output is in python/unit-tests.log."
 
 export PYSPARK_PYTHON="python"
@@ -96,6 +101,7 @@ $PYSPARK_PYTHON --version
 run_core_tests
 run_sql_tests
 run_mllib_tests
+run_streaming_tests
 
 # Try to test with PyPy
 if [ $(which pypy) ]; then
@@ -105,6 +111,7 @@ if [ $(which pypy) ]; then
 
     run_core_tests
     run_sql_tests
+    run_streaming_tests
 fi
 
 if [[ $FAILED == 0 ]]; then

http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index a6184de..2a7004e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
     new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
   }
 
-    /**
+   /**
    * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
    * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
    * of the RDD.

http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
new file mode 100644
index 0000000..213dff6
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.python
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.lang.reflect.Proxy
+import java.util.{ArrayList => JArrayList, List => JList}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import py4j.GatewayServer
+
+import org.apache.spark.api.java._
+import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Interval, Duration, Time}
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.api.java._
+
+
+/**
+ * Interface for Python callback function which is used to transform RDDs
+ */
+private[python] trait PythonTransformFunction {
+  def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+}
+
+/**
+ * Interface for Python Serializer to serialize PythonTransformFunction
+ */
+private[python] trait PythonTransformFunctionSerializer {
+  def dumps(id: String): Array[Byte]
+  def loads(bytes: Array[Byte]): PythonTransformFunction
+}
+
+/**
+ * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J)
+ * so that it looks like a Scala function and can be transparently serialized and
+ * deserialized by Java.
+ */
+private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction)
+  extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
+
+  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
+      .map(_.rdd)
+  }
+
+  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
+    Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+  }
+
+  // for function.Function2
+  def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
+    pfunc.call(time.milliseconds, rdds)
+  }
+
+  private def writeObject(out: ObjectOutputStream): Unit = {
+    val bytes = PythonTransformFunctionSerializer.serialize(pfunc)
+    out.writeInt(bytes.length)
+    out.write(bytes)
+  }
+
+  private def readObject(in: ObjectInputStream): Unit = {
+    val length = in.readInt()
+    val bytes = new Array[Byte](length)
+    in.readFully(bytes)
+    pfunc = PythonTransformFunctionSerializer.deserialize(bytes)
+  }
+}
+
+/**
+ * Helpers for PythonTransformFunctionSerializer
+ *
+ * PythonTransformFunctionSerializer is logically a singleton that happens to be
+ * implemented as a Python object.
+ */
+private[python] object PythonTransformFunctionSerializer {
+
+  /**
+   * A serializer in Python, used to serialize PythonTransformFunction
+   */
+  private var serializer: PythonTransformFunctionSerializer = _
+
+  /*
+   * Register a serializer from Python, should be called during initialization
+   */
+  def register(ser: PythonTransformFunctionSerializer): Unit = {
+    serializer = ser
+  }
+
+  def serialize(func: PythonTransformFunction): Array[Byte] = {
+    assert(serializer != null, "Serializer has not been registered!")
+    // get the id of PythonTransformFunction in py4j
+    val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
+    val f = h.getClass().getDeclaredField("id")
+    f.setAccessible(true)
+    val id = f.get(h).asInstanceOf[String]
+    serializer.dumps(id)
+  }
+
+  def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+    assert(serializer != null, "Serializer has not been registered!")
+    serializer.loads(bytes)
+  }
+}
+
+/**
+ * Helper functions, which are called from Python via Py4J.
+ */
+private[python] object PythonDStream {
+
+  /**
+   * Cannot access PythonTransformFunctionSerializer.register() directly via Py4J:
+   * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
+   */
+  def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
+    PythonTransformFunctionSerializer.register(ser)
+  }
+
+  /**
+   * Update the port of callback client to `port`
+   */
+  def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
+    val cl = gws.getCallbackClient
+    val f = cl.getClass.getDeclaredField("port")
+    f.setAccessible(true)
+    f.setInt(cl, port)
+  }
+
+  /**
+   * Helper function for DStream.foreachRDD();
+   * it cannot be named `foreachRDD`, as that would confuse Py4J.
+   */
+  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
+    val func = new TransformFunction(pfunc)
+    jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
+  }
+
+  /**
+   * Convert a list of RDDs into a queue of RDDs, for ssc.queueStream().
+   */
+  def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
+    val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
+    rdds.forall(queue.add(_))
+    queue
+  }
+}
+
+/**
+ * Base class for PythonDStream with some common methods
+ */
+private[python] abstract class PythonDStream(
+    parent: DStream[_],
+    @transient pfunc: PythonTransformFunction)
+  extends DStream[Array[Byte]] (parent.ssc) {
+
+  val func = new TransformFunction(pfunc)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * Transformed DStream in Python.
+ */
+private[python] class PythonTransformedDStream (
+    parent: DStream[_],
+    @transient pfunc: PythonTransformFunction)
+  extends PythonDStream(parent, pfunc) {
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val rdd = parent.getOrCompute(validTime)
+    if (rdd.isDefined) {
+      func(rdd, validTime)
+    } else {
+      None
+    }
+  }
+}
+
+/**
+ * Transformed from two DStreams in Python.
+ */
+private[python] class PythonTransformed2DStream(
+    parent: DStream[_],
+    parent2: DStream[_],
+    @transient pfunc: PythonTransformFunction)
+  extends DStream[Array[Byte]] (parent.ssc) {
+
+  val func = new TransformFunction(pfunc)
+
+  override def dependencies = List(parent, parent2)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val empty: RDD[_] = ssc.sparkContext.emptyRDD
+    val rdd1 = parent.getOrCompute(validTime).getOrElse(empty)
+    val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty)
+    func(Some(rdd1), Some(rdd2), validTime)
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * similar to StateDStream
+ */
+private[python] class PythonStateDStream(
+    parent: DStream[Array[Byte]],
+    @transient reduceFunc: PythonTransformFunction)
+  extends PythonDStream(parent, reduceFunc) {
+
+  super.persist(StorageLevel.MEMORY_ONLY)
+  override val mustCheckpoint = true
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val lastState = getOrCompute(validTime - slideDuration)
+    val rdd = parent.getOrCompute(validTime)
+    if (rdd.isDefined) {
+      func(lastState, rdd, validTime)
+    } else {
+      lastState
+    }
+  }
+}
+
+/**
+ * similar to ReducedWindowedDStream
+ */
+private[python] class PythonReducedWindowedDStream(
+    parent: DStream[Array[Byte]],
+    @transient preduceFunc: PythonTransformFunction,
+    @transient pinvReduceFunc: PythonTransformFunction,
+    _windowDuration: Duration,
+    _slideDuration: Duration)
+  extends PythonDStream(parent, preduceFunc) {
+
+  super.persist(StorageLevel.MEMORY_ONLY)
+  override val mustCheckpoint = true
+
+  val invReduceFunc = new TransformFunction(pinvReduceFunc)
+
+  def windowDuration: Duration = _windowDuration
+  override def slideDuration: Duration = _slideDuration
+  override def parentRememberDuration: Duration = rememberDuration + windowDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val currentTime = validTime
+    val current = new Interval(currentTime - windowDuration, currentTime)
+    val previous = current - slideDuration
+
+    //  _____________________________
+    // |  previous window   _________|___________________
+    // |___________________|       current window        |  --------------> Time
+    //                     |_____________________________|
+    //
+    // |________ _________|          |________ _________|
+    //          |                             |
+    //          V                             V
+    //       old RDDs                     new RDDs
+    //
+    val previousRDD = getOrCompute(previous.endTime)
+
+    // for a small window, reducing once is better than reducing twice
+    if (pinvReduceFunc != null && previousRDD.isDefined
+        && windowDuration >= slideDuration * 5) {
+
+      // subtract the values from old RDDs
+      val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
+      val subtracted = if (oldRDDs.size > 0) {
+        invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
+      } else {
+        previousRDD
+      }
+
+      // add the RDDs of the reduced values in "new time steps"
+      val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
+      if (newRDDs.size > 0) {
+        func(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
+      } else {
+        subtracted
+      }
+    } else {
+      // Get the RDDs of the reduced values in current window
+      val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
+      if (currentRDDs.size > 0) {
+        func(None, Some(ssc.sc.union(currentRDDs)), validTime)
+      } else {
+        None
+      }
+    }
+  }
+}
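
To make the incremental window arithmetic above concrete, here is a plain-Python sketch (no Spark, made-up batch values) of what PythonReducedWindowedDStream.compute() does when an inverse reduce function is supplied: subtract the batches that just fell out of the window from the previous window's result, then add the batches that just entered it. Note that the Scala code only takes this path when windowDuration >= 5 * slideDuration; otherwise it simply re-reduces the whole current window.

# batch time -> per-batch reduced value, using addition as the reduce function
batches = {1: 10, 2: 20, 3: 30, 4: 40, 5: 50}
window, slide = 3, 1  # measured in batch intervals

def full_reduce(t):
    # Re-reduce the whole current window (t - window, t], as the non-incremental branch does.
    return sum(batches.get(i, 0) for i in range(t - window + 1, t + 1))

def incremental_reduce(prev_result, t):
    # Batches that left the window since the last slide ("old RDDs"):
    old = sum(batches.get(i, 0) for i in range(t - window - slide + 1, t - window + 1))
    # Batches that entered the window since the last slide ("new RDDs"):
    new = sum(batches.get(i, 0) for i in range(t - slide + 1, t + 1))
    # invReduceFunc subtracts the old values, then func adds the new ones.
    return prev_result - old + new

prev = full_reduce(4)               # 20 + 30 + 40 = 90
print(incremental_reduce(prev, 5))  # 90 - 20 + 50 = 120
print(full_reduce(5))               # 30 + 40 + 50 = 120, the same result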

