Repository: spark
Updated Branches:
  refs/heads/master f6255d7b7 -> 03306a6df


http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_readwrite.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_readwrite.py 
b/python/pyspark/tests/test_readwrite.py
new file mode 100644
index 0000000..e45f5b3
--- /dev/null
+++ b/python/pyspark/tests/test_readwrite.py
@@ -0,0 +1,499 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+from array import array
+
+from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME
+
+
+class InputFormatTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name)
+
+    @unittest.skipIf(sys.version >= "3", "serialize array of byte")
+    def test_sequencefiles(self):
+        basepath = self.tempdir.name
+        ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
+                                           "org.apache.hadoop.io.IntWritable",
+                                           "org.apache.hadoop.io.Text").collect())
+        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+        self.assertEqual(ints, ei)
+
+        doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/",
+                                              "org.apache.hadoop.io.DoubleWritable",
+                                              "org.apache.hadoop.io.Text").collect())
+        ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
+        self.assertEqual(doubles, ed)
+
+        bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/",
+                                            "org.apache.hadoop.io.IntWritable",
+                                            "org.apache.hadoop.io.BytesWritable").collect())
+        ebs = [(1, bytearray('aa', 'utf-8')),
+               (1, bytearray('aa', 'utf-8')),
+               (2, bytearray('aa', 'utf-8')),
+               (2, bytearray('bb', 'utf-8')),
+               (2, bytearray('bb', 'utf-8')),
+               (3, bytearray('cc', 'utf-8'))]
+        self.assertEqual(bytes, ebs)
+
+        text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/",
+                                           "org.apache.hadoop.io.Text",
+                                           "org.apache.hadoop.io.Text").collect())
+        et = [(u'1', u'aa'),
+              (u'1', u'aa'),
+              (u'2', u'aa'),
+              (u'2', u'bb'),
+              (u'2', u'bb'),
+              (u'3', u'cc')]
+        self.assertEqual(text, et)
+
+        bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/",
+                                            "org.apache.hadoop.io.IntWritable",
+                                            "org.apache.hadoop.io.BooleanWritable").collect())
+        eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
+        self.assertEqual(bools, eb)
+
+        nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/",
+                                            "org.apache.hadoop.io.IntWritable",
+                                            "org.apache.hadoop.io.BooleanWritable").collect())
+        en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
+        self.assertEqual(nulls, en)
+
+        maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
+                                    "org.apache.hadoop.io.IntWritable",
+                                    "org.apache.hadoop.io.MapWritable").collect()
+        em = [(1, {}),
+              (1, {3.0: u'bb'}),
+              (2, {1.0: u'aa'}),
+              (2, {1.0: u'cc'}),
+              (3, {2.0: u'dd'})]
+        for v in maps:
+            self.assertTrue(v in em)
+
+        # arrays get pickled to tuples by default
+        tuples = sorted(self.sc.sequenceFile(
+            basepath + "/sftestdata/sfarray/",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.spark.api.python.DoubleArrayWritable").collect())
+        et = [(1, ()),
+              (2, (3.0, 4.0, 5.0)),
+              (3, (4.0, 5.0, 6.0))]
+        self.assertEqual(tuples, et)
+
+        # with custom converters, primitive arrays can stay as arrays
+        arrays = sorted(self.sc.sequenceFile(
+            basepath + "/sftestdata/sfarray/",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.spark.api.python.DoubleArrayWritable",
+            valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
+        ea = [(1, array('d')),
+              (2, array('d', [3.0, 4.0, 5.0])),
+              (3, array('d', [4.0, 5.0, 6.0]))]
+        self.assertEqual(arrays, ea)
+
+        clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
+                                            "org.apache.hadoop.io.Text",
+                                            "org.apache.spark.api.python.TestWritable").collect())
+        cname = u'org.apache.spark.api.python.TestWritable'
+        ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}),
+              (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}),
+              (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}),
+              (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}),
+              (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})]
+        self.assertEqual(clazz, ec)
+
+        unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
+                                                      "org.apache.hadoop.io.Text",
+                                                      "org.apache.spark.api.python.TestWritable",
+                                                      ).collect())
+        self.assertEqual(unbatched_clazz, ec)
+
+    def test_oldhadoop(self):
+        basepath = self.tempdir.name
+        ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/",
+                                         "org.apache.hadoop.mapred.SequenceFileInputFormat",
+                                         "org.apache.hadoop.io.IntWritable",
+                                         "org.apache.hadoop.io.Text").collect())
+        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+        self.assertEqual(ints, ei)
+
+        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+        oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath}
+        hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
+                                  "org.apache.hadoop.io.LongWritable",
+                                  "org.apache.hadoop.io.Text",
+                                  conf=oldconf).collect()
+        result = [(0, u'Hello World!')]
+        self.assertEqual(hello, result)
+
+    def test_newhadoop(self):
+        basepath = self.tempdir.name
+        ints = sorted(self.sc.newAPIHadoopFile(
+            basepath + "/sftestdata/sfint/",
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text").collect())
+        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+        self.assertEqual(ints, ei)
+
+        hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+        newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath}
+        hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
+                                        "org.apache.hadoop.io.LongWritable",
+                                        "org.apache.hadoop.io.Text",
+                                        conf=newconf).collect()
+        result = [(0, u'Hello World!')]
+        self.assertEqual(hello, result)
+
+    def test_newolderror(self):
+        basepath = self.tempdir.name
+        self.assertRaises(Exception, lambda: self.sc.hadoopFile(
+            basepath + "/sftestdata/sfint/",
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text"))
+
+        self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
+            basepath + "/sftestdata/sfint/",
+            "org.apache.hadoop.mapred.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text"))
+
+    def test_bad_inputs(self):
+        basepath = self.tempdir.name
+        self.assertRaises(Exception, lambda: self.sc.sequenceFile(
+            basepath + "/sftestdata/sfint/",
+            "org.apache.hadoop.io.NotValidWritable",
+            "org.apache.hadoop.io.Text"))
+        self.assertRaises(Exception, lambda: self.sc.hadoopFile(
+            basepath + "/sftestdata/sfint/",
+            "org.apache.hadoop.mapred.NotValidInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text"))
+        self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
+            basepath + "/sftestdata/sfint/",
+            "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text"))
+
+    def test_converters(self):
+        # use of custom converters
+        basepath = self.tempdir.name
+        maps = sorted(self.sc.sequenceFile(
+            basepath + "/sftestdata/sfmap/",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable",
+            keyConverter="org.apache.spark.api.python.TestInputKeyConverter",
+            valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect())
+        em = [(u'\x01', []),
+              (u'\x01', [3.0]),
+              (u'\x02', [1.0]),
+              (u'\x02', [1.0]),
+              (u'\x03', [2.0])]
+        self.assertEqual(maps, em)
+
+    def test_binary_files(self):
+        path = os.path.join(self.tempdir.name, "binaryfiles")
+        os.mkdir(path)
+        data = b"short binary data"
+        with open(os.path.join(path, "part-0000"), 'wb') as f:
+            f.write(data)
+        [(p, d)] = self.sc.binaryFiles(path).collect()
+        self.assertTrue(p.endswith("part-0000"))
+        self.assertEqual(d, data)
+
+    def test_binary_records(self):
+        path = os.path.join(self.tempdir.name, "binaryrecords")
+        os.mkdir(path)
+        with open(os.path.join(path, "part-0000"), 'w') as f:
+            for i in range(100):
+                f.write('%04d' % i)
+        result = self.sc.binaryRecords(path, 4).map(int).collect()
+        self.assertEqual(list(range(100)), result)
+
+
+class OutputFormatTests(ReusedPySparkTestCase):
+
+    def setUp(self):
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir.name, ignore_errors=True)
+
+    @unittest.skipIf(sys.version >= "3", "serialize array of byte")
+    def test_sequencefiles(self):
+        basepath = self.tempdir.name
+        ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+        self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/")
+        ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect())
+        self.assertEqual(ints, ei)
+
+        ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
+        self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/")
+        doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect())
+        self.assertEqual(doubles, ed)
+
+        ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))]
+        self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/")
+        bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect())
+        self.assertEqual(bytes, ebs)
+
+        et = [(u'1', u'aa'),
+              (u'2', u'bb'),
+              (u'3', u'cc')]
+        self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/")
+        text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect())
+        self.assertEqual(text, et)
+
+        eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
+        self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/")
+        bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect())
+        self.assertEqual(bools, eb)
+
+        en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
+        self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/")
+        nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect())
+        self.assertEqual(nulls, en)
+
+        em = [(1, {}),
+              (1, {3.0: u'bb'}),
+              (2, {1.0: u'aa'}),
+              (2, {1.0: u'cc'}),
+              (3, {2.0: u'dd'})]
+        self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
+        maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
+        for v in maps:
+            self.assertTrue(v in em)
+
+    def test_oldhadoop(self):
+        basepath = self.tempdir.name
+        dict_data = [(1, {}),
+                     (1, {"row1": 1.0}),
+                     (2, {"row2": 2.0})]
+        self.sc.parallelize(dict_data).saveAsHadoopFile(
+            basepath + "/oldhadoop/",
+            "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable")
+        result = self.sc.hadoopFile(
+            basepath + "/oldhadoop/",
+            "org.apache.hadoop.mapred.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable").collect()
+        for v in result:
+            self.assertTrue(v in dict_data)
+
+        conf = {
+            "mapred.output.format.class": 
"org.apache.hadoop.mapred.SequenceFileOutputFormat",
+            "mapreduce.job.output.key.class": 
"org.apache.hadoop.io.IntWritable",
+            "mapreduce.job.output.value.class": 
"org.apache.hadoop.io.MapWritable",
+            "mapreduce.output.fileoutputformat.outputdir": basepath + 
"/olddataset/"
+        }
+        self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
+        input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"}
+        result = self.sc.hadoopRDD(
+            "org.apache.hadoop.mapred.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable",
+            conf=input_conf).collect()
+        for v in result:
+            self.assertTrue(v in dict_data)
+
+    def test_newhadoop(self):
+        basepath = self.tempdir.name
+        data = [(1, ""),
+                (1, "a"),
+                (2, "bcdf")]
+        self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text")
+        result = sorted(self.sc.newAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text").collect())
+        self.assertEqual(result, data)
+
+        conf = {
+            "mapreduce.job.outputformat.class":
+                "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+            "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text",
+            "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
+        }
+        self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
+        input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
+        new_dataset = sorted(self.sc.newAPIHadoopRDD(
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text",
+            conf=input_conf).collect())
+        self.assertEqual(new_dataset, data)
+
+    @unittest.skipIf(sys.version >= "3", "serialize of array")
+    def test_newhadoop_with_array(self):
+        basepath = self.tempdir.name
+        # use custom ArrayWritable types and converters to handle arrays
+        array_data = [(1, array('d')),
+                      (1, array('d', [1.0, 2.0, 3.0])),
+                      (2, array('d', [3.0, 4.0, 5.0]))]
+        self.sc.parallelize(array_data).saveAsNewAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.spark.api.python.DoubleArrayWritable",
+            valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
+        result = sorted(self.sc.newAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.spark.api.python.DoubleArrayWritable",
+            valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
+        self.assertEqual(result, array_data)
+
+        conf = {
+            "mapreduce.job.outputformat.class":
+                "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+            "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable",
+            "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
+        }
+        self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(
+            conf,
+            valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
+        input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
+        new_dataset = sorted(self.sc.newAPIHadoopRDD(
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.spark.api.python.DoubleArrayWritable",
+            valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter",
+            conf=input_conf).collect())
+        self.assertEqual(new_dataset, array_data)
+
+    def test_newolderror(self):
+        basepath = self.tempdir.name
+        rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
+        self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
+            basepath + "/newolderror/saveAsHadoopFile/",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat"))
+        self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
+            basepath + "/newolderror/saveAsNewAPIHadoopFile/",
+            "org.apache.hadoop.mapred.SequenceFileOutputFormat"))
+
+    def test_bad_inputs(self):
+        basepath = self.tempdir.name
+        rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
+        self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
+            basepath + "/badinputs/saveAsHadoopFile/",
+            "org.apache.hadoop.mapred.NotValidOutputFormat"))
+        self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
+            basepath + "/badinputs/saveAsNewAPIHadoopFile/",
+            "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat"))
+
+    def test_converters(self):
+        # use of custom converters
+        basepath = self.tempdir.name
+        data = [(1, {3.0: u'bb'}),
+                (2, {1.0: u'aa'}),
+                (3, {2.0: u'dd'})]
+        self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+            basepath + "/converters/",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            keyConverter="org.apache.spark.api.python.TestOutputKeyConverter",
+            valueConverter="org.apache.spark.api.python.TestOutputValueConverter")
+        converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect())
+        expected = [(u'1', 3.0),
+                    (u'2', 1.0),
+                    (u'3', 2.0)]
+        self.assertEqual(converted, expected)
+
+    def test_reserialization(self):
+        basepath = self.tempdir.name
+        x = range(1, 5)
+        y = range(1001, 1005)
+        data = list(zip(x, y))
+        rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y))
+        rdd.saveAsSequenceFile(basepath + "/reserialize/sequence")
+        result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect())
+        self.assertEqual(result1, data)
+
+        rdd.saveAsHadoopFile(
+            basepath + "/reserialize/hadoop",
+            "org.apache.hadoop.mapred.SequenceFileOutputFormat")
+        result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect())
+        self.assertEqual(result2, data)
+
+        rdd.saveAsNewAPIHadoopFile(
+            basepath + "/reserialize/newhadoop",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")
+        result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect())
+        self.assertEqual(result3, data)
+
+        conf4 = {
+            "mapred.output.format.class": 
"org.apache.hadoop.mapred.SequenceFileOutputFormat",
+            "mapreduce.job.output.key.class": 
"org.apache.hadoop.io.IntWritable",
+            "mapreduce.job.output.value.class": 
"org.apache.hadoop.io.IntWritable",
+            "mapreduce.output.fileoutputformat.outputdir": basepath + 
"/reserialize/dataset"}
+        rdd.saveAsHadoopDataset(conf4)
+        result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect())
+        self.assertEqual(result4, data)
+
+        conf5 = {"mapreduce.job.outputformat.class":
+                 "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+                 "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+                 "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable",
+                 "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset"
+                 }
+        rdd.saveAsNewAPIHadoopDataset(conf5)
+        result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect())
+        self.assertEqual(result5, data)
+
+    def test_malformed_RDD(self):
+        basepath = self.tempdir.name
+        # non-batch-serialized RDD[[(K, V)]] should be rejected
+        data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
+        rdd = self.sc.parallelize(data, len(data))
+        self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile(
+            basepath + "/malformed/sequence"))
+
+
+if __name__ == "__main__":
+    from pyspark.tests.test_readwrite import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
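
For reference, a minimal sketch of the SequenceFile round trip these tests exercise
(outside the patch; the local master and the /tmp output path are illustrative
assumptions):

    from pyspark import SparkContext

    sc = SparkContext("local", "sequencefile-sketch")
    pairs = [(1, u'aa'), (2, u'bb'), (3, u'cc')]
    # keys and values are converted to Hadoop Writables on write ...
    sc.parallelize(pairs).saveAsSequenceFile("/tmp/sf-sketch")
    # ... and back to Python objects on read
    read_back = sorted(sc.sequenceFile("/tmp/sf-sketch").collect())
    assert read_back == pairs
    sc.stop()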

http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_serializers.py 
b/python/pyspark/tests/test_serializers.py
new file mode 100644
index 0000000..bce9406
--- /dev/null
+++ b/python/pyspark/tests/test_serializers.py
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import math
+import sys
+import unittest
+
+from pyspark import serializers
+from pyspark.serializers import *
+from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \
+    AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \
+    FlattenedValuesSerializer, CartesianDeserializer
+from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \
+    have_numpy, have_scipy
+
+
+class SerializationTestCase(unittest.TestCase):
+
+    def test_namedtuple(self):
+        from collections import namedtuple
+        from pickle import dumps, loads
+        P = namedtuple("P", "x y")
+        p1 = P(1, 3)
+        p2 = loads(dumps(p1, 2))
+        self.assertEqual(p1, p2)
+
+        from pyspark.cloudpickle import dumps
+        P2 = loads(dumps(P))
+        p3 = P2(1, 3)
+        self.assertEqual(p1, p3)
+
+    def test_itemgetter(self):
+        from operator import itemgetter
+        ser = CloudPickleSerializer()
+        d = range(10)
+        getter = itemgetter(1)
+        getter2 = ser.loads(ser.dumps(getter))
+        self.assertEqual(getter(d), getter2(d))
+
+        getter = itemgetter(0, 3)
+        getter2 = ser.loads(ser.dumps(getter))
+        self.assertEqual(getter(d), getter2(d))
+
+    def test_function_module_name(self):
+        ser = CloudPickleSerializer()
+        func = lambda x: x
+        func2 = ser.loads(ser.dumps(func))
+        self.assertEqual(func.__module__, func2.__module__)
+
+    def test_attrgetter(self):
+        from operator import attrgetter
+        ser = CloudPickleSerializer()
+
+        class C(object):
+            def __getattr__(self, item):
+                return item
+        d = C()
+        getter = attrgetter("a")
+        getter2 = ser.loads(ser.dumps(getter))
+        self.assertEqual(getter(d), getter2(d))
+        getter = attrgetter("a", "b")
+        getter2 = ser.loads(ser.dumps(getter))
+        self.assertEqual(getter(d), getter2(d))
+
+        d.e = C()
+        getter = attrgetter("e.a")
+        getter2 = ser.loads(ser.dumps(getter))
+        self.assertEqual(getter(d), getter2(d))
+        getter = attrgetter("e.a", "e.b")
+        getter2 = ser.loads(ser.dumps(getter))
+        self.assertEqual(getter(d), getter2(d))
+
+    # Regression test for SPARK-3415
+    def test_pickling_file_handles(self):
+        # to be corrected with SPARK-11160
+        try:
+            import xmlrunner
+        except ImportError:
+            ser = CloudPickleSerializer()
+            out1 = sys.stderr
+            out2 = ser.loads(ser.dumps(out1))
+            self.assertEqual(out1, out2)
+
+    def test_func_globals(self):
+
+        class Unpicklable(object):
+            def __reduce__(self):
+                raise Exception("not picklable")
+
+        global exit
+        exit = Unpicklable()
+
+        ser = CloudPickleSerializer()
+        self.assertRaises(Exception, lambda: ser.dumps(exit))
+
+        def foo():
+            sys.exit(0)
+
+        self.assertTrue("exit" in foo.__code__.co_names)
+        ser.dumps(foo)
+
+    def test_compressed_serializer(self):
+        ser = CompressedSerializer(PickleSerializer())
+        try:
+            from StringIO import StringIO
+        except ImportError:
+            from io import BytesIO as StringIO
+        io = StringIO()
+        ser.dump_stream(["abc", u"123", range(5)], io)
+        io.seek(0)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+        ser.dump_stream(range(1000), io)
+        io.seek(0)
+        self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), 
list(ser.load_stream(io)))
+        io.close()
+
+    def test_hash_serializer(self):
+        hash(NoOpSerializer())
+        hash(UTF8Deserializer())
+        hash(PickleSerializer())
+        hash(MarshalSerializer())
+        hash(AutoSerializer())
+        hash(BatchedSerializer(PickleSerializer()))
+        hash(AutoBatchedSerializer(MarshalSerializer()))
+        hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
+        hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
+        hash(CompressedSerializer(PickleSerializer()))
+        hash(FlattenedValuesSerializer(PickleSerializer()))
+
+
+@unittest.skipIf(not have_scipy, "SciPy not installed")
+class SciPyTests(PySparkTestCase):
+
+    """General PySpark tests that depend on scipy """
+
+    def test_serialize(self):
+        from scipy.special import gammaln
+
+        x = range(1, 5)
+        expected = list(map(gammaln, x))
+        observed = self.sc.parallelize(x).map(gammaln).collect()
+        self.assertEqual(expected, observed)
+
+
+@unittest.skipIf(not have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+
+    """General PySpark tests that depend on numpy """
+
+    def test_statcounter_array(self):
+        import numpy as np
+
+        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
+        s = x.stats()
+        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
+        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+
+        stats_dict = s.asDict()
+        self.assertEqual(3, stats_dict['count'])
+        self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
+        self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
+        self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
+
+        stats_sample_dict = s.asDict(sample=True)
+        self.assertEqual(3, stats_sample_dict['count'])
+        self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
+        self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
+        self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
+        self.assertSequenceEqual(
+            [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
+        self.assertSequenceEqual(
+            [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
+
+
+class SerializersTest(unittest.TestCase):
+
+    def test_chunked_stream(self):
+        original_bytes = bytearray(range(100))
+        for data_length in [1, 10, 100]:
+            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+                dest = ByteArrayOutput()
+                stream_out = serializers.ChunkedStream(dest, buffer_length)
+                stream_out.write(original_bytes[:data_length])
+                stream_out.close()
+                num_chunks = int(math.ceil(float(data_length) / buffer_length))
+                # length for each chunk, and a final -1 at the very end
+                exp_size = (num_chunks + 1) * 4 + data_length
+                self.assertEqual(len(dest.buffer), exp_size)
+                dest_pos = 0
+                data_pos = 0
+                for chunk_idx in range(num_chunks):
+                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+                    if chunk_idx == num_chunks - 1:
+                        exp_length = data_length % buffer_length
+                        if exp_length == 0:
+                            exp_length = buffer_length
+                    else:
+                        exp_length = buffer_length
+                    self.assertEqual(chunk_length, exp_length)
+                    dest_pos += 4
+                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+                    self.assertEqual(dest_chunk, orig_chunk)
+                    dest_pos += chunk_length
+                    data_pos += chunk_length
+                # ends with a -1
+                self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+if __name__ == "__main__":
+    from pyspark.tests.test_serializers import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
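
For reference, a minimal sketch of the CompressedSerializer round trip checked in
test_compressed_serializer above (outside the patch; the in-memory BytesIO buffer
is an illustrative stand-in for a worker stream):

    from io import BytesIO
    from pyspark.serializers import CompressedSerializer, PickleSerializer

    ser = CompressedSerializer(PickleSerializer())
    buf = BytesIO()
    data = ["abc", u"123", list(range(5))]
    ser.dump_stream(data, buf)    # each object is pickled, then compressed and framed
    buf.seek(0)
    assert list(ser.load_stream(buf)) == data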

http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_shuffle.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_shuffle.py 
b/python/pyspark/tests/test_shuffle.py
new file mode 100644
index 0000000..0489426
--- /dev/null
+++ b/python/pyspark/tests/test_shuffle.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+import sys
+import unittest
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext
+from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
+
+if sys.version_info[0] >= 3:
+    xrange = range
+
+
+class MergerTests(unittest.TestCase):
+
+    def setUp(self):
+        self.N = 1 << 12
+        self.l = [i for i in xrange(self.N)]
+        self.data = list(zip(self.l, self.l))
+        self.agg = Aggregator(lambda x: [x],
+                              lambda x, y: x.append(y) or x,
+                              lambda x, y: x.extend(y) or x)
+
+    def test_small_dataset(self):
+        m = ExternalMerger(self.agg, 1000)
+        m.mergeValues(self.data)
+        self.assertEqual(m.spills, 0)
+        self.assertEqual(sum(sum(v) for k, v in m.items()),
+                         sum(xrange(self.N)))
+
+        m = ExternalMerger(self.agg, 1000)
+        m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
+        self.assertEqual(m.spills, 0)
+        self.assertEqual(sum(sum(v) for k, v in m.items()),
+                         sum(xrange(self.N)))
+
+    def test_medium_dataset(self):
+        m = ExternalMerger(self.agg, 20)
+        m.mergeValues(self.data)
+        self.assertTrue(m.spills >= 1)
+        self.assertEqual(sum(sum(v) for k, v in m.items()),
+                         sum(xrange(self.N)))
+
+        m = ExternalMerger(self.agg, 10)
+        m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
+        self.assertTrue(m.spills >= 1)
+        self.assertEqual(sum(sum(v) for k, v in m.items()),
+                         sum(xrange(self.N)) * 3)
+
+    def test_huge_dataset(self):
+        m = ExternalMerger(self.agg, 5, partitions=3)
+        m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
+        self.assertTrue(m.spills >= 1)
+        self.assertEqual(sum(len(v) for k, v in m.items()),
+                         self.N * 10)
+        m._cleanup()
+
+    def test_group_by_key(self):
+
+        def gen_data(N, step):
+            for i in range(1, N + 1, step):
+                for j in range(i):
+                    yield (i, [j])
+
+        def gen_gs(N, step=1):
+            return shuffle.GroupByKey(gen_data(N, step))
+
+        self.assertEqual(1, len(list(gen_gs(1))))
+        self.assertEqual(2, len(list(gen_gs(2))))
+        self.assertEqual(100, len(list(gen_gs(100))))
+        self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
+        self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))
+
+        for k, vs in gen_gs(50002, 10000):
+            self.assertEqual(k, len(vs))
+            self.assertEqual(list(range(k)), list(vs))
+
+        ser = PickleSerializer()
+        l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+        for k, vs in l:
+            self.assertEqual(k, len(vs))
+            self.assertEqual(list(range(k)), list(vs))
+
+    def test_stopiteration_is_raised(self):
+
+        def stopit(*args, **kwargs):
+            raise StopIteration()
+
+        def legit_create_combiner(x):
+            return [x]
+
+        def legit_merge_value(x, y):
+            return x.append(y) or x
+
+        def legit_merge_combiners(x, y):
+            return x.extend(y) or x
+
+        data = [(x % 2, x) for x in range(100)]
+
+        # wrong create combiner
+        m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeValues(data)
+
+        # wrong merge value
+        m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeValues(data)
+
+        # wrong merge combiners
+        m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+
+
+class SorterTests(unittest.TestCase):
+    def test_in_memory_sort(self):
+        l = list(range(1024))
+        random.shuffle(l)
+        sorter = ExternalSorter(1024)
+        self.assertEqual(sorted(l), list(sorter.sorted(l)))
+        self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+        self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+        self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
+                         list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+
+    def test_external_sort(self):
+        class CustomizedSorter(ExternalSorter):
+            def _next_limit(self):
+                return self.memory_limit
+        l = list(range(1024))
+        random.shuffle(l)
+        sorter = CustomizedSorter(1)
+        self.assertEqual(sorted(l), list(sorter.sorted(l)))
+        self.assertGreater(shuffle.DiskBytesSpilled, 0)
+        last = shuffle.DiskBytesSpilled
+        self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
+        self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
+        self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
+                         list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+
+    def test_external_sort_in_rdd(self):
+        conf = SparkConf().set("spark.python.worker.memory", "1m")
+        sc = SparkContext(conf=conf)
+        l = list(range(10240))
+        random.shuffle(l)
+        rdd = sc.parallelize(l, 4)
+        self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
+        sc.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.tests.test_shuffle import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
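
For reference, a minimal sketch of the Aggregator/ExternalMerger combination these
merger tests drive (outside the patch; the small key space and the 10 MB memory
limit are illustrative assumptions, whereas the tests above use limits low enough
to force spills):

    from pyspark.shuffle import Aggregator, ExternalMerger

    agg = Aggregator(lambda v: [v],                       # createCombiner
                     lambda c, v: c.append(v) or c,       # mergeValue
                     lambda c1, c2: c1.extend(c2) or c1)  # mergeCombiners
    merger = ExternalMerger(agg, memory_limit=10)         # limit in MB
    merger.mergeValues((i % 4, i) for i in range(100))    # (key, value) pairs
    grouped = dict(merger.items())                        # key -> list of values
    assert sorted(grouped) == [0, 1, 2, 3]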

http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_taskcontext.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_taskcontext.py 
b/python/pyspark/tests/test_taskcontext.py
new file mode 100644
index 0000000..b3a9674
--- /dev/null
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+import sys
+import time
+
+from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark.testing.utils import PySparkTestCase
+
+
+class TaskContextTests(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        # Allow retries even though they are normally disabled in local mode
+        self.sc = SparkContext('local[4, 2]', class_name)
+
+    def test_stage_id(self):
+        """Test the stage ids are available and incrementing as expected."""
+        rdd = self.sc.parallelize(range(10))
+        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+        # Test using the constructor directly rather than the get()
+        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
+        self.assertEqual(stage1 + 1, stage2)
+        self.assertEqual(stage1 + 2, stage3)
+        self.assertEqual(stage2 + 1, stage3)
+
+    def test_partition_id(self):
+        """Test the partition id."""
+        rdd1 = self.sc.parallelize(range(10), 1)
+        rdd2 = self.sc.parallelize(range(10), 2)
+        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
+        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
+        self.assertEqual(0, pids1[0])
+        self.assertEqual(0, pids1[9])
+        self.assertEqual(0, pids2[0])
+        self.assertEqual(1, pids2[9])
+
+    def test_attempt_number(self):
+        """Verify the attempt numbers are correctly reported."""
+        rdd = self.sc.parallelize(range(10))
+        # Verify a simple job with no failures
+        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
+        for attempt in attempt_numbers:
+            self.assertEqual(0, attempt)
+
+        def fail_on_first(x):
+            """Fail on the first attempt so we get a positive attempt number"""
+            tc = TaskContext.get()
+            attempt_number = tc.attemptNumber()
+            partition_id = tc.partitionId()
+            attempt_id = tc.taskAttemptId()
+            if attempt_number == 0 and partition_id == 0:
+                raise Exception("Failing on first attempt")
+            else:
+                return [x, partition_id, attempt_number, attempt_id]
+        result = rdd.map(fail_on_first).collect()
+        # We should re-submit the first partition, but other partitions should be attempt 0
+        self.assertEqual([0, 0, 1], result[0][0:3])
+        self.assertEqual([9, 3, 0], result[9][0:3])
+        first_partition = [x for x in result if x[1] == 0]
+        for x in first_partition:
+            self.assertEqual(1, x[2])
+        other_partitions = [x for x in result if x[1] != 0]
+        for x in other_partitions:
+            self.assertEqual(0, x[2])
+        # The task attempt id should be different
+        self.assertTrue(result[0][3] != result[9][3])
+
+    def test_tc_on_driver(self):
+        """Verify that getting the TaskContext on the driver returns None."""
+        tc = TaskContext.get()
+        self.assertTrue(tc is None)
+
+    def test_get_local_property(self):
+        """Verify that local properties set on the driver are available in 
TaskContext."""
+        key = "testkey"
+        value = "testvalue"
+        self.sc.setLocalProperty(key, value)
+        try:
+            rdd = self.sc.parallelize(range(1), 1)
+            prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
+            self.assertEqual(prop1, value)
+            prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
+            self.assertTrue(prop2 is None)
+        finally:
+            self.sc.setLocalProperty(key, None)
+
+    def test_barrier(self):
+        """
+        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
+        within a stage.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return time.time()
+
+        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        self.assertTrue(max(times) - min(times) < 1)
+
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Verify that BarrierTaskContext.barrier() works with a reused Python worker.
+        """
+        self.sc._conf.set("spark.python.work.reuse", "true")
+        rdd = self.sc.parallelize(range(4), 4)
+        # start a normal job first to start all worker
+        result = rdd.map(lambda x: x ** 2).collect()
+        self.assertEqual([0, 1, 4, 9], result)
+        # make sure `spark.python.work.reuse=true`
+        self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true")
+
+        # worker will be reused in this barrier job
+        self.test_barrier()
+
+    def test_barrier_infos(self):
+        """
+        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
+        barrier stage.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
+                                                       .getTaskInfos()).collect()
+        self.assertTrue(len(taskInfos) == 4)
+        self.assertTrue(len(taskInfos[0]) == 4)
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_taskcontext import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
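
For reference, a minimal sketch of reading TaskContext fields from inside a task,
as the tests above do (outside the patch; the local[2] master is an illustrative
assumption):

    from pyspark import SparkContext, TaskContext

    sc = SparkContext("local[2]", "taskcontext-sketch")
    pairs = (sc.parallelize(range(4), 2)
               .map(lambda _: (TaskContext.get().partitionId(),
                               TaskContext.get().attemptNumber()))
               .collect())
    # two partitions, each succeeding on its first attempt
    assert sorted(set(pairs)) == [(0, 0), (1, 0)]
    sc.stop()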

http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_util.py 
b/python/pyspark/tests/test_util.py
new file mode 100644
index 0000000..11cda8f
--- /dev/null
+++ b/python/pyspark/tests/test_util.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import keyword_only
+from pyspark.testing.utils import PySparkTestCase
+
+
+class KeywordOnlyTests(unittest.TestCase):
+    class Wrapped(object):
+        @keyword_only
+        def set(self, x=None, y=None):
+            if "x" in self._input_kwargs:
+                self._x = self._input_kwargs["x"]
+            if "y" in self._input_kwargs:
+                self._y = self._input_kwargs["y"]
+            return x, y
+
+    def test_keywords(self):
+        w = self.Wrapped()
+        x, y = w.set(y=1)
+        self.assertEqual(y, 1)
+        self.assertEqual(y, w._y)
+        self.assertIsNone(x)
+        self.assertFalse(hasattr(w, "_x"))
+
+    def test_non_keywords(self):
+        w = self.Wrapped()
+        self.assertRaises(TypeError, lambda: w.set(0, y=1))
+
+    def test_kwarg_ownership(self):
+        # test _input_kwargs is owned by each class instance and not a shared static variable
+        class Setter(object):
+            @keyword_only
+            def set(self, x=None, other=None, other_x=None):
+                if "other" in self._input_kwargs:
+                    self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
+                self._x = self._input_kwargs["x"]
+
+        a = Setter()
+        b = Setter()
+        a.set(x=1, other=b, other_x=2)
+        self.assertEqual(a._x, 1)
+        self.assertEqual(b._x, 2)
+
+
+class UtilTests(PySparkTestCase):
+    def test_py4j_exception_message(self):
+        from pyspark.util import _exception_message
+
+        with self.assertRaises(Py4JJavaError) as context:
+            # This attempts java.lang.String(null) which throws an NPE.
+            self.sc._jvm.java.lang.String(None)
+
+        self.assertTrue('NullPointerException' in _exception_message(context.exception))
+
+    def test_parsing_version_string(self):
+        from pyspark.util import VersionUtils
+        self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced"))
+
+
+if __name__ == "__main__":
+    from pyspark.tests.test_util import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
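
For reference, a minimal sketch of the keyword_only decorator behaviour covered by
KeywordOnlyTests (outside the patch; Point is an illustrative class, not a pyspark
API):

    from pyspark import keyword_only

    class Point(object):
        @keyword_only
        def __init__(self, x=None, y=None):
            # keyword_only stashes the passed keyword arguments on the instance
            kwargs = self._input_kwargs
            self.x = kwargs.get("x")
            self.y = kwargs.get("y")

    p = Point(x=1, y=2)
    assert (p.x, p.y) == (1, 2)
    try:
        Point(1, 2)          # positional arguments are rejected
    except TypeError:
        pass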

http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests/test_worker.py 
b/python/pyspark/tests/test_worker.py
new file mode 100644
index 0000000..a33b77d
--- /dev/null
+++ b/python/pyspark/tests/test_worker.py
@@ -0,0 +1,157 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+import tempfile
+import threading
+import time
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest
+
+if sys.version_info[0] >= 3:
+    xrange = range
+
+
+class WorkerTests(ReusedPySparkTestCase):
+    def test_cancel_task(self):
+        temp = tempfile.NamedTemporaryFile(delete=True)
+        temp.close()
+        path = temp.name
+
+        def sleep(x):
+            import os
+            import time
+            with open(path, 'w') as f:
+                f.write("%d %d" % (os.getppid(), os.getpid()))
+            time.sleep(100)
+
+        # start job in background thread
+        def run():
+            try:
+                self.sc.parallelize(range(1), 1).foreach(sleep)
+            except Exception:
+                pass
+        import threading
+        t = threading.Thread(target=run)
+        t.daemon = True
+        t.start()
+
+        daemon_pid, worker_pid = 0, 0
+        while True:
+            if os.path.exists(path):
+                with open(path) as f:
+                    data = f.read().split(' ')
+                daemon_pid, worker_pid = map(int, data)
+                break
+            time.sleep(0.1)
+
+        # cancel jobs
+        self.sc.cancelAllJobs()
+        t.join()
+
+        for i in range(50):
+            try:
+                os.kill(worker_pid, 0)
+                time.sleep(0.1)
+            except OSError:
+                break  # worker was killed
+        else:
+            self.fail("worker has not been killed after 5 seconds")
+
+        try:
+            os.kill(daemon_pid, 0)
+        except OSError:
+            self.fail("daemon had been killed")
+
+        # run a normal job
+        rdd = self.sc.parallelize(xrange(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_after_exception(self):
+        def raise_exception(_):
+            raise Exception()
+        rdd = self.sc.parallelize(xrange(100), 1)
+        with QuietTest(self.sc):
+            self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_after_jvm_exception(self):
+        tempFile = tempfile.NamedTemporaryFile(delete=False)
+        tempFile.write(b"Hello World!")
+        tempFile.close()
+        data = self.sc.textFile(tempFile.name, 1)
+        filtered_data = data.filter(lambda x: True)
+        self.assertEqual(1, filtered_data.count())
+        os.unlink(tempFile.name)
+        with QuietTest(self.sc):
+            self.assertRaises(Exception, lambda: filtered_data.count())
+
+        rdd = self.sc.parallelize(xrange(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_accumulator_when_reuse_worker(self):
+        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
+        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
+        self.assertEqual(sum(range(100)), acc1.value)
+
+        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
+        self.assertEqual(sum(range(100)), acc2.value)
+        self.assertEqual(sum(range(100)), acc1.value)
+
+    def test_reuse_worker_after_take(self):
+        rdd = self.sc.parallelize(xrange(100000), 1)
+        self.assertEqual(0, rdd.first())
+
+        def count():
+            try:
+                rdd.count()
+            except Exception:
+                pass
+
+        t = threading.Thread(target=count)
+        t.daemon = True
+        t.start()
+        t.join(5)
+        self.assertTrue(not t.isAlive())
+        self.assertEqual(100000, rdd.count())
+
+    def test_with_different_versions_of_python(self):
+        rdd = self.sc.parallelize(range(10))
+        rdd.count()
+        version = self.sc.pythonVer
+        self.sc.pythonVer = "2.0"
+        try:
+            with QuietTest(self.sc):
+                self.assertRaises(Py4JJavaError, lambda: rdd.count())
+        finally:
+            self.sc.pythonVer = version
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_worker import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
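
For reference, a minimal sketch of the QuietTest pattern used above to silence the
expected executor errors while a job is deliberately failed (outside the patch; the
failing function and local master are illustrative assumptions):

    from pyspark import SparkContext
    from pyspark.testing.utils import QuietTest

    def boom(_):
        raise Exception("expected failure")

    sc = SparkContext("local", "worker-sketch")
    rdd = sc.parallelize(range(10), 1)
    with QuietTest(sc):                  # temporarily raises the JVM root log level
        try:
            rdd.foreach(boom)
        except Exception:
            pass
    assert rdd.map(str).count() == 10    # the worker still serves later jobs
    sc.stop()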

