Repository: spark Updated Branches: refs/heads/master 7a3f589ef -> 69c67abaa
http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/python/pyspark/streaming/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000..a8d876d --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,545 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from itertools import chain +import time +import operator +import unittest +import tempfile + +from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.streaming.context import StreamingContext + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 10 # seconds + duration = 1 + + def setUp(self): + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) + self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + self.ssc.stop() + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. + if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + +class BasicOperationTests(PySparkStreamingTestCase): + + def test_map(self): + """Basic operation test for DStream.map.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.map(str) + expected = map(lambda x: map(str, x), input) + self._test_func(input, func, expected) + + def test_flatMap(self): + """Basic operation test for DStream.faltMap.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + input) + self._test_func(input, func, expected) + + def test_filter(self): + """Basic operation test for DStream.filter.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + self._test_func(input, func, expected) + + def test_count(self): + """Basic operation test for DStream.count.""" + input = [range(5), range(10), range(20)] + + def func(dstream): + return dstream.count() + expected = map(lambda x: [len(x)], input) + self._test_func(input, func, expected) + + def test_reduce(self): + """Basic operation test for DStream.reduce.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.reduce(operator.add) + expected = map(lambda x: [reduce(operator.add, x)], input) + self._test_func(input, func, expected) + + def test_reduceByKey(self): + """Basic operation test for DStream.reduceByKey.""" + input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def func(dstream): + return dstream.reduceByKey(operator.add) + expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + self._test_func(input, func, expected, sort=True) + + def test_mapValues(self): + """Basic operation test for DStream.mapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + self._test_func(input, func, expected, sort=True) + + def test_flatMapValues(self): + """Basic operation test for DStream.flatMapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + self._test_func(input, func, expected) + + def test_glom(self): + """Basic operation test for DStream.glom.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.glom() + expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + self._test_func(rdds, func, expected) + + def test_mapPartitions(self): + """Basic operation test for DStream.mapPartitions.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected = [[3, 7], [11, 15], [19, 23]] + self._test_func(rdds, func, expected) + + def test_countByValue(self): + """Basic operation test for DStream.countByValue.""" + input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def func(dstream): + return dstream.countByValue() + expected = [[4], [4], [3]] + self._test_func(input, func, expected) + + def test_groupByKey(self): + """Basic operation test for DStream.groupByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + return dstream.groupByKey().mapValues(list) + + expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + self._test_func(input, func, expected, sort=True) + + def test_combineByKey(self): + """Basic operation test for DStream.combineByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + self._test_func(input, func, expected, sort=True) + + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartition(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + + def test_union(self): + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] + + def func(d1, d2): + return d1.union(d2) + + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_update_state_by_key(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + +class WindowFunctionTests(PySparkStreamingTestCase): + + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.countByWindow(3, 1) + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window_large(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByWindow(5, 1) + + expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] + self._test_func(input, func, expected) + + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(5, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + + def test_group_by_key_and_window(self): + input = [[('a', i)] for i in range(5)] + + def func(dstream): + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + + expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], + [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] + self._test_func(input, func, expected) + + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop() + self.ssc.stop() + + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) + + def test_union(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + +class CheckpointTests(PySparkStreamingTestCase): + + def setUp(self): + pass + + def test_get_or_create(self): + inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(vs, s): + return sum(vs, s or 0) + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.5) + return ssc + + cpd = tempfile.mkdtemp("test_streaming_cps") + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) + ssc.stop(True, True) + + time.sleep(1) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) + + +if __name__ == "__main__": + unittest.main() http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/python/pyspark/streaming/util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py new file mode 100644 index 0000000..86ee5aa --- /dev/null +++ b/python/pyspark/streaming/util.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +from datetime import datetime +import traceback + +from pyspark import SparkContext, RDD + + +class TransformFunction(object): + """ + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. + """ + _emptyRDD = None + + def __init__(self, ctx, func, *deserializers): + self.ctx = ctx + self.func = func + self.deserializers = deserializers + + def call(self, milliseconds, jrdds): + try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return + + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) + + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + for jrdd, ser in zip(jrdds, sers)] + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) + if r: + return r._jrdd + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunction(%s)" % self.func + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] + + +class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. + """ + def __init__(self, ctx, serializer, gateway=None): + self.ctx = ctx + self.serializer = serializer + self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) + + def dumps(self, id): + try: + func = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return TransformFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] + + +def rddToFileName(prefix, suffix, timestamp): + """ + Return string prefix-time(.suffix) + + >>> rddToFileName("spark", None, 12345678910) + 'spark-12345678910' + >>> rddToFileName("spark", "tmp", 12345678910) + 'spark-12345678910.tmp' + """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + if suffix is None: + return prefix + "-" + str(timestamp) + else: + return prefix + "-" + str(timestamp) + "." + suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/python/run-tests ---------------------------------------------------------------------- diff --git a/python/run-tests b/python/run-tests index f6a9684..2f98443 100755 --- a/python/run-tests +++ b/python/run-tests @@ -81,6 +81,11 @@ function run_mllib_tests() { run_test "pyspark/mllib/tests.py" } +function run_streaming_tests() { + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -96,6 +101,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -105,6 +111,7 @@ if [ $(which pypy) ]; then run_core_tests run_sql_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala ---------------------------------------------------------------------- diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de..2a7004e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } - /** + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. http://git-wip-us.apache.org/repos/asf/spark/blob/69c67aba/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala ---------------------------------------------------------------------- diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala new file mode 100644 index 0000000..213dff6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.python + +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.language.existentials + +import py4j.GatewayServer + +import org.apache.spark.api.java._ +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Interval, Duration, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.api.java._ + + +/** + * Interface for Python callback function which is used to transform RDDs + */ +private[python] trait PythonTransformFunction { + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] +} + +/** + * Interface for Python Serializer to serialize PythonTransformFunction + */ +private[python] trait PythonTransformFunctionSerializer { + def dumps(id: String): Array[Byte] + def loads(bytes: Array[Byte]): PythonTransformFunction +} + +/** + * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J) + * so that it looks like a Scala function and can be transparently serialized and + * deserialized by Java. + */ +private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) + .map(_.rdd) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + } + + // for function.Function2 + def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { + pfunc.call(time.milliseconds, rdds) + } + + private def writeObject(out: ObjectOutputStream): Unit = { + val bytes = PythonTransformFunctionSerializer.serialize(pfunc) + out.writeInt(bytes.length) + out.write(bytes) + } + + private def readObject(in: ObjectInputStream): Unit = { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + pfunc = PythonTransformFunctionSerializer.deserialize(bytes) + } +} + +/** + * Helpers for PythonTransformFunctionSerializer + * + * PythonTransformFunctionSerializer is logically a singleton that's happens to be + * implemented as a Python object. + */ +private[python] object PythonTransformFunctionSerializer { + + /** + * A serializer in Python, used to serialize PythonTransformFunction + */ + private var serializer: PythonTransformFunctionSerializer = _ + + /* + * Register a serializer from Python, should be called during initialization + */ + def register(ser: PythonTransformFunctionSerializer): Unit = { + serializer = ser + } + + def serialize(func: PythonTransformFunction): Array[Byte] = { + assert(serializer != null, "Serializer has not been registered!") + // get the id of PythonTransformFunction in py4j + val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) + val f = h.getClass().getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] + serializer.dumps(id) + } + + def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + assert(serializer != null, "Serializer has not been registered!") + serializer.loads(bytes) + } +} + +/** + * Helper functions, which are called from Python via Py4J. + */ +private[python] object PythonDStream { + + /** + * can not access PythonTransformFunctionSerializer.register() via Py4j + * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + */ + def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = { + PythonTransformFunctionSerializer.register(ser) + } + + /** + * Update the port of callback client to `port` + */ + def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { + val cl = gws.getCallbackClient + val f = cl.getClass.getDeclaredField("port") + f.setAccessible(true) + f.setInt(cl, port) + } + + /** + * helper function for DStream.foreachRDD(), + * cannot be `foreachRDD`, it will confusing py4j + */ + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { + val func = new TransformFunction((pfunc)) + jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) + } + + /** + * convert list of RDD into queue of RDDs, for ssc.queueStream() + */ + def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { + val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] + rdds.forall(queue.add(_)) + queue + } +} + +/** + * Base class for PythonDStream with some common methods + */ +private[python] abstract class PythonDStream( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * Transformed DStream in Python. + */ +private[python] class PythonTransformedDStream ( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends PythonDStream(parent, pfunc) { + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(rdd, validTime) + } else { + None + } + } +} + +/** + * Transformed from two DStreams in Python. + */ +private[python] class PythonTransformed2DStream( + parent: DStream[_], + parent2: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent, parent2) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val empty: RDD[_] = ssc.sparkContext.emptyRDD + val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) + val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty) + func(Some(rdd1), Some(rdd2), validTime) + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * similar to StateDStream + */ +private[python] class PythonStateDStream( + parent: DStream[Array[Byte]], + @transient reduceFunc: PythonTransformFunction) + extends PythonDStream(parent, reduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(lastState, rdd, validTime) + } else { + lastState + } + } +} + +/** + * similar to ReducedWindowedDStream + */ +private[python] class PythonReducedWindowedDStream( + parent: DStream[Array[Byte]], + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, + _windowDuration: Duration, + _slideDuration: Duration) + extends PythonDStream(parent, preduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + val invReduceFunc = new TransformFunction(pinvReduceFunc) + + def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val currentTime = validTime + val current = new Interval(currentTime - windowDuration, currentTime) + val previous = current - slideDuration + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + val previousRDD = getOrCompute(previous.endTime) + + // for small window, reduce once will be better than twice + if (pinvReduceFunc != null && previousRDD.isDefined + && windowDuration >= slideDuration * 5) { + + // subtract the values from old RDDs + val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime) + val subtracted = if (oldRDDs.size > 0) { + invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime) + } else { + previousRDD + } + + // add the RDDs of the reduced values in "new time steps" + val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) + if (newRDDs.size > 0) { + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + } else { + subtracted + } + } else { + // Get the RDDs of the reduced values in current window + val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) + if (currentRDDs.size > 0) { + func(None, Some(ssc.sc.union(currentRDDs)), validTime) + } else { + None + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org