Repository: spark
Updated Branches:
  refs/heads/master f6255d7b7 -> 03306a6df
http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_readwrite.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py new file mode 100644 index 0000000..e45f5b3 --- /dev/null +++ b/python/pyspark/tests/test_readwrite.py @@ -0,0 +1,499 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import sys +import tempfile +import unittest +from array import array + +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME + + +class InputFormatTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", + "org.apache.hadoop.io.DoubleWritable", + "org.apache.hadoop.io.Text").collect()) + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.assertEqual(doubles, ed) + + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", + "org.apache.hadoop.io.Text", + "org.apache.hadoop.io.Text").collect()) + et = [(u'1', u'aa'), + (u'1', u'aa'), + (u'2', u'aa'), + (u'2', u'bb'), + (u'2', u'bb'), + (u'3', u'cc')] + self.assertEqual(text, et) + + bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.assertEqual(bools, eb) + + nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", + "org.apache.hadoop.io.IntWritable", + 
"org.apache.hadoop.io.BooleanWritable").collect()) + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.assertEqual(nulls, en) + + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + for v in maps: + self.assertTrue(v in em) + + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable").collect()) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) + + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + ).collect()) + self.assertEqual(unbatched_clazz, ec) + + def test_oldhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + 
conf=newconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newolderror(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.sequenceFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.NotValidWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + maps = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] + self.assertEqual(maps, em) + + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(list(range(100)), result) + + +class OutputFormatTests(ReusedPySparkTestCase): + + def setUp(self): + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath 
+ "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + self.assertEqual(nulls, en) + + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) + + def test_oldhadoop(self): + basepath = self.tempdir.name + dict_data = [(1, {}), + (1, {"row1": 1.0}), + (2, {"row2": 2.0})] + self.sc.parallelize(dict_data).saveAsHadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable") + result = self.sc.hadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) + + conf = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" + } + self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} + result = self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) + + def test_newhadoop(self): + basepath = self.tempdir.name + data = [(1, ""), + (1, "a"), + (2, "bcdf")] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + self.assertEqual(result, data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + 
"org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=input_conf).collect()) + self.assertEqual(new_dataset, data) + + @unittest.skipIf(sys.version >= "3", "serialize of array") + def test_newhadoop_with_array(self): + basepath = self.tempdir.name + # use custom ArrayWritable types and converters to handle arrays + array_data = [(1, array('d')), + (1, array('d', [1.0, 2.0, 3.0])), + (2, array('d', [3.0, 4.0, 5.0]))] + self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + 
"/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = list(zip(x, y)) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile( + basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" + } + rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) + + +if __name__ == "__main__": + from pyspark.tests.test_readwrite import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_serializers.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py new file mode 100644 index 0000000..bce9406 --- /dev/null +++ b/python/pyspark/tests/test_serializers.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import math +import sys +import unittest + +from pyspark import serializers +from pyspark.serializers import * +from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \ + AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \ + FlattenedValuesSerializer, CartesianDeserializer +from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \ + have_numpy, have_scipy + + +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from pickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEqual(p1, p2) + + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + + def test_itemgetter(self): + from operator import itemgetter + ser = CloudPickleSerializer() + d = range(10) + getter = itemgetter(1) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + getter = itemgetter(0, 3) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + def test_function_module_name(self): + ser = CloudPickleSerializer() + func = lambda x: x + func2 = ser.loads(ser.dumps(func)) + self.assertEqual(func.__module__, func2.__module__) + + def test_attrgetter(self): + from operator import attrgetter + ser = CloudPickleSerializer() + + class C(object): + def __getattr__(self, item): + return item + d = C() + getter = attrgetter("a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("a", "b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + d.e = C() + getter = attrgetter("e.a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("e.a", "e.b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + # Regression test for SPARK-3415 + def test_pickling_file_handles(self): + # to be corrected with SPARK-11160 + try: + import xmlrunner + except ImportError: + ser = CloudPickleSerializer() + out1 = sys.stderr + out2 = ser.loads(ser.dumps(out1)) + self.assertEqual(out1, out2) + + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.__code__.co_names) + ser.dumps(foo) + + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + ser.dump_stream(range(1000), io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def 
test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) + hash(FlattenedValuesSerializer(PickleSerializer())) + + +@unittest.skipIf(not have_scipy, "SciPy not installed") +class SciPyTests(PySparkTestCase): + + """General PySpark tests that depend on scipy """ + + def test_serialize(self): + from scipy.special import gammaln + + x = range(1, 5) + expected = list(map(gammaln, x)) + observed = self.sc.parallelize(x).map(gammaln).collect() + self.assertEqual(expected, observed) + + +@unittest.skipIf(not have_numpy, "NumPy not installed") +class NumPyTests(PySparkTestCase): + + """General PySpark tests that depend on numpy """ + + def test_statcounter_array(self): + import numpy as np + + x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) + s = x.stats() + self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) + self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) + + stats_dict = s.asDict() + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) + + stats_sample_dict = s.asDict(sample=True) + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) + self.assertSequenceEqual( + [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) + self.assertSequenceEqual( + [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) + + +class SerializersTest(unittest.TestCase): + + def test_chunked_stream(self): + original_bytes = bytearray(range(100)) + for data_length in [1, 10, 100]: + for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: + dest = ByteArrayOutput() + stream_out = serializers.ChunkedStream(dest, buffer_length) + stream_out.write(original_bytes[:data_length]) + stream_out.close() + num_chunks = int(math.ceil(float(data_length) / buffer_length)) + # length for each chunk, and a final -1 at the very end + exp_size = (num_chunks + 1) * 4 + data_length + self.assertEqual(len(dest.buffer), exp_size) + dest_pos = 0 + data_pos = 0 + for chunk_idx in range(num_chunks): + chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) + if chunk_idx == num_chunks - 1: + exp_length = data_length % buffer_length + if exp_length == 0: + exp_length = buffer_length + else: + exp_length = buffer_length + self.assertEqual(chunk_length, exp_length) + dest_pos += 4 + dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] + orig_chunk = 
original_bytes[data_pos:data_pos + chunk_length] + self.assertEqual(dest_chunk, orig_chunk) + dest_pos += chunk_length + data_pos += chunk_length + # ends with a -1 + self.assertEqual(dest.buffer[-4:], write_int(-1)) + + +if __name__ == "__main__": + from pyspark.tests.test_serializers import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_shuffle.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py new file mode 100644 index 0000000..0489426 --- /dev/null +++ b/python/pyspark/tests/test_shuffle.py @@ -0,0 +1,181 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter + +if sys.version_info[0] >= 3: + xrange = range + + +class MergerTests(unittest.TestCase): + + def setUp(self): + self.N = 1 << 12 + self.l = [i for i in xrange(self.N)] + self.data = list(zip(self.l, self.l)) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) + + def test_small_dataset(self): + m = ExternalMerger(self.agg, 1000) + m.mergeValues(self.data) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 1000) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + def test_medium_dataset(self): + m = ExternalMerger(self.agg, 20) + m.mergeValues(self.data) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N)) * 3) + + def test_huge_dataset(self): + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(len(v) for k, v in m.items()), + self.N * 10) + m._cleanup() + + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + 
self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + + +class SorterTests(unittest.TestCase): + def test_in_memory_sort(self): + l = list(range(1024)) + random.shuffle(l) + sorter = ExternalSorter(1024) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + + def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit + l = list(range(1024)) + random.shuffle(l) + sorter = CustomizedSorter(1) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + + def test_external_sort_in_rdd(self): + conf = SparkConf().set("spark.python.worker.memory", "1m") + sc = SparkContext(conf=conf) + l = list(range(10240)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_shuffle import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_taskcontext.py 
---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py new file mode 100644 index 0000000..b3a9674 --- /dev/null +++ b/python/pyspark/tests/test_taskcontext.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import time + +from pyspark import SparkContext, TaskContext, BarrierTaskContext +from pyspark.testing.utils import PySparkTestCase + + +class TaskContextTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + # Allow retries even though they are normally disabled in local mode + self.sc = SparkContext('local[4, 2]', class_name) + + def test_stage_id(self): + """Test the stage ids are available and incrementing as expected.""" + rdd = self.sc.parallelize(range(10)) + stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + # Test using the constructor directly rather than the get() + stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] + self.assertEqual(stage1 + 1, stage2) + self.assertEqual(stage1 + 2, stage3) + self.assertEqual(stage2 + 1, stage3) + + def test_partition_id(self): + """Test the partition id.""" + rdd1 = self.sc.parallelize(range(10), 1) + rdd2 = self.sc.parallelize(range(10), 2) + pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() + pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() + self.assertEqual(0, pids1[0]) + self.assertEqual(0, pids1[9]) + self.assertEqual(0, pids2[0]) + self.assertEqual(1, pids2[9]) + + def test_attempt_number(self): + """Verify the attempt numbers are correctly reported.""" + rdd = self.sc.parallelize(range(10)) + # Verify a simple job with no failures + attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() + map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) + + def fail_on_first(x): + """Fail on the first attempt so we get a positive attempt number""" + tc = TaskContext.get() + attempt_number = tc.attemptNumber() + partition_id = tc.partitionId() + attempt_id = tc.taskAttemptId() + if attempt_number == 0 and partition_id == 0: + raise Exception("Failing on first attempt") + else: + return [x, partition_id, attempt_number, attempt_id] + result = rdd.map(fail_on_first).collect() + # We should re-submit the first partition to it but other partitions should be attempt 0 + self.assertEqual([0, 0, 1], result[0][0:3]) + self.assertEqual([9, 3, 0], result[9][0:3]) + first_partition = filter(lambda x: x[1] == 0, result) + map(lambda x: self.assertEqual(1, x[2]), first_partition) + other_partitions = filter(lambda x: x[1] != 
0, result) + map(lambda x: self.assertEqual(0, x[2]), other_partitions) + # The task attempt id should be different + self.assertTrue(result[0][3] != result[9][3]) + + def test_tc_on_driver(self): + """Verify that getting the TaskContext on the driver returns None.""" + tc = TaskContext.get() + self.assertTrue(tc is None) + + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_with_python_worker_reuse(self): + """ + Verify that BarrierTaskContext.barrier() with reused python worker. + """ + self.sc._conf.set("spark.python.work.reuse", "true") + rdd = self.sc.parallelize(range(4), 4) + # start a normal job first to start all worker + result = rdd.map(lambda x: x ** 2).collect() + self.assertEqual([0, 1, 4, 9], result) + # make sure `spark.python.work.reuse=true` + self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") + + # worker will be reused in this barrier job + self.test_barrier() + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_taskcontext import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py new file mode 100644 index 0000000..11cda8f --- /dev/null +++ b/python/pyspark/tests/test_util.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import keyword_only +from pyspark.testing.utils import PySparkTestCase + + +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + + +if __name__ == "__main__": + from pyspark.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py new file mode 100644 index 0000000..a33b77d --- /dev/null +++ b/python/pyspark/tests/test_worker.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys +import tempfile +import threading +import time + +from py4j.protocol import Py4JJavaError + +from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class WorkerTests(ReusedPySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + + def sleep(x): + import os + import time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + with open(path) as f: + data = f.read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + # run a normal job + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(xrange(100), 1) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + + def test_reuse_worker_after_take(self): + rdd = self.sc.parallelize(xrange(100000), 1) + self.assertEqual(0, rdd.first()) + + def count(): + try: + rdd.count() + except Exception: + pass + + t = threading.Thread(target=count) + t.daemon = True + t.start() + t.join(5) + self.assertTrue(not t.isAlive()) + self.assertEqual(100000, rdd.count()) + + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" + try: + with QuietTest(self.sc): + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + self.sc.pythonVer = version + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_worker import * + + try: + import xmlrunner + testRunner = 
xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2)
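For reference, every module added by this commit keeps a standalone __main__ entry point like the ones above, so an individual suite can still be run on its own. Below is a minimal sketch of driving one of the new modules programmatically; it assumes a local Spark checkout with pyspark and py4j importable on PYTHONPATH, the module name pyspark.tests.test_util is taken from the diff, and everything else is illustrative rather than part of the commit.

    import unittest

    try:
        # Optional dependency, mirroring the __main__ blocks added in this commit.
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        runner = unittest.TextTestRunner(verbosity=2)

    # Load one of the new modules by name and run it; any of the other
    # pyspark.tests.test_* modules from this commit could be substituted.
    suite = unittest.defaultTestLoader.loadTestsFromName("pyspark.tests.test_util")
    runner.run(suite)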