http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/__init__.py b/python/pyspark/tests/__init__.py new file mode 100644 index 0000000..12bdf0d --- /dev/null +++ b/python/pyspark/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_appsubmit.py
# ----------------------------------------------------------------------
# diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py
# new file mode 100644
# index 0000000..92bcb11
# --- /dev/null
# +++ b/python/pyspark/tests/test_appsubmit.py
# @@ -0,0 +1,248 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import shutil
import subprocess
import tempfile
import unittest
import zipfile


class SparkSubmitTests(unittest.TestCase):
    """Tests that run small PySpark programs through ``bin/spark-submit``."""

    # Matches the margin (leading spaces up to and including the first '|')
    # of every line in a multi-line string literal.
    _MARGIN = re.compile(r'^ *\|', re.MULTILINE)

    def setUp(self):
        spark_home = os.environ.get("SPARK_HOME")
        if spark_home is None:
            # Without SPARK_HOME there is no spark-submit to launch, and
            # os.path.join(None, ...) would fail with an obscure TypeError.
            # This check runs before mkdtemp() because tearDown() is not
            # called when setUp() skips, so nothing would remove the directory.
            self.skipTest("SPARK_HOME is not set")
        self.programDir = tempfile.mkdtemp()
        tmp_dir = tempfile.gettempdir()
        self.sparkSubmit = [
            os.path.join(spark_home, "bin", "spark-submit"),
            "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
            "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
        ]

    def tearDown(self):
        shutil.rmtree(self.programDir)

    @classmethod
    def _strip_margin(cls, content):
        """Strip ``content``, then remove each line's margin up to its first '|'."""
        return cls._MARGIN.sub('', content.strip())

    def _ensure_subdir(self, subdir):
        """Return ``programDir/subdir``, creating it (and parents) if missing."""
        target = os.path.join(self.programDir, subdir)
        if not os.path.isdir(target):
            os.makedirs(target)
        return target

    def createTempFile(self, name, content, dir=None):
        """
        Create a temp file with the given name and content and return its path.
        Strips leading spaces from content up to the first '|' in each line.

        :param name: file name.
        :param content: '|'-margined text to write.
        :param dir: optional sub-directory of ``self.programDir``. It is created
            if missing; an existing directory is reused.
        """
        content = self._strip_margin(content)
        parent = self.programDir if dir is None else self._ensure_subdir(dir)
        path = os.path.join(parent, name)
        with open(path, "w") as f:
            f.write(content)
        return path

    def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
        """
        Create a zip archive containing a file with the given content and return its path.
        Strips leading spaces from content up to the first '|' in each line.

        :param name: name of the single entry stored in the archive.
        :param content: '|'-margined text for that entry.
        :param ext: archive extension, e.g. ".zip" or ".jar".
        :param dir: optional sub-directory of ``self.programDir``, created if
            missing. Without it the archive is ``programDir/<name><ext>``.
        :param zip_name: archive base name used together with ``dir``;
            defaults to ``name``.
        """
        content = self._strip_margin(content)
        if dir is None:
            path = os.path.join(self.programDir, name + ext)
        else:
            base = name if zip_name is None else zip_name
            path = os.path.join(self._ensure_subdir(dir), base + ext)
        # The context manager closes the archive even if writestr() fails.
        with zipfile.ZipFile(path, 'w') as archive:
            archive.writestr(name, content)
        return path

    def create_spark_package(self, artifact_name):
        """
        Lay out a local Maven-style repository under ``self.programDir`` for
        the ``group:artifact:version`` coordinate ``artifact_name``. It holds
        a POM and a jar whose only entry is ``<artifact>.py`` defining ``myfunc``.
        """
        group_id, artifact_id, version = artifact_name.split(":")
        self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
            |<?xml version="1.0" encoding="UTF-8"?>
            |<project xmlns="http://maven.apache.org/POM/4.0.0"
            |       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
            |       xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
            |       http://maven.apache.org/xsd/maven-4.0.0.xsd">
            |   <modelVersion>4.0.0</modelVersion>
            |   <groupId>%s</groupId>
            |   <artifactId>%s</artifactId>
            |   <version>%s</version>
            |</project>
            """ % (group_id, artifact_id, version)).lstrip(),
            os.path.join(group_id, artifact_id, version))
        self.createFileInZip("%s.py" % artifact_id, """
            |def myfunc(x):
            |    return x + 1
            """, ".jar", os.path.join(group_id, artifact_id, version),
            "%s-%s" % (artifact_id, version))

    def test_single_script(self):
        """Submit and test a single script file"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
            """)
        proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 4, 6]", out.decode('utf-8'))

    def test_script_with_local_functions(self):
        """Submit and test a single script file calling a global function"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |
            |def foo(x):
            |    return x * 3
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(foo).collect())
            """)
        proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[3, 6, 9]", out.decode('utf-8'))

    def test_module_dependency(self):
        """Submit and test a script with a dependency on another module"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        zip = self.createFileInZip("mylib.py", """
            |def myfunc(x):
            |    return x + 1
            """)
        proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script],
                                stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_module_dependency_on_cluster(self):
        """Submit and test a script with a dependency on another module on a cluster"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        zip = self.createFileInZip("mylib.py", """
            |def myfunc(x):
            |    return x + 1
            """)
        proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master",
                                                    "local-cluster[1,1,1024]", script],
                                stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_package_dependency(self):
        """Submit and test a script with a dependency on a Spark Package"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        self.create_spark_package("a:mylib:0.1")
        proc = subprocess.Popen(
            self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
                                "file:" + self.programDir, script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_package_dependency_on_cluster(self):
        """Submit and test a script with a dependency on a Spark Package on a cluster"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |from mylib import myfunc
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
            """)
        self.create_spark_package("a:mylib:0.1")
        proc = subprocess.Popen(
            self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
                                "file:" + self.programDir, "--master", "local-cluster[1,1,1024]",
                                script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 3, 4]", out.decode('utf-8'))

    def test_single_script_on_cluster(self):
        """Submit and test a single script on a cluster"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkContext
            |
            |def foo(x):
            |    return x * 2
            |
            |sc = SparkContext()
            |print(sc.parallelize([1, 2, 3]).map(foo).collect())
            """)
        # this will fail if you have different spark.executor.memory
        # in conf/spark-defaults.conf
        proc = subprocess.Popen(
            self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script],
            stdout=subprocess.PIPE)
        out, err = proc.communicate()
        self.assertEqual(0, proc.returncode)
        self.assertIn("[2, 4, 6]", out.decode('utf-8'))

    def test_user_configuration(self):
        """Make sure user configuration is respected (SPARK-19307)"""
        script = self.createTempFile("test.py", """
            |from pyspark import SparkConf, SparkContext
            |
            |conf = SparkConf().set("spark.test_config", "1")
            |sc = SparkContext(conf = conf)
            |try:
            |    if sc._conf.get("spark.test_config") != "1":
            |        raise Exception("Cannot find spark.test_config in SparkContext's conf.")
            |finally:
            |    sc.stop()
            """)
        proc = subprocess.Popen(
            self.sparkSubmit + ["--master", "local", script],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT)
        out, err = proc.communicate()
        # Decode so the failure message shows readable text rather than a bytes
        # repr; 'replace' keeps a passing run from failing on odd output bytes.
        self.assertEqual(0, proc.returncode,
                         msg="Process failed with error:\n {0}".format(
                             out.decode('utf-8', 'replace')))


if __name__ == "__main__":
    from pyspark.tests.test_appsubmit import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_broadcast.py
# ----------------------------------------------------------------------
# diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py
# new file mode 100644
# index 0000000..a98626e
# --- /dev/null
# +++ b/python/pyspark/tests/test_broadcast.py
# @@ -0,0 +1,122 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import tempfile
import unittest

from pyspark import SparkConf, SparkContext
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import ChunkedStream


class BroadcastTest(unittest.TestCase):

    def tearDown(self):
        if getattr(self, "sc", None) is not None:
            self.sc.stop()
            self.sc = None

    def _test_encryption_helper(self, vs):
        """
        Create a broadcast variable for each value in vs and run a simple job to make sure
        the values read in the executors are the same. Also make sure there are no task
        failures.
        """
        bs = [self.sc.broadcast(value=v) for v in vs]
        exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
        for ev in exec_values:
            self.assertEqual(ev, vs)
        # make sure there are no task failures
        status = self.sc.statusTracker()
        for jid in status.getJobIdsForGroup():
            for sid in status.getJobInfo(jid).stageIds:
                stage_info = status.getStageInfo(sid)
                self.assertEqual(0, stage_info.numFailedTasks)

    def _test_multiple_broadcasts(self, *extra_confs):
        """
        Test that broadcast variables reach the executors intact. Covers multiple broadcast
        variables and multiple jobs.

        :param extra_confs: ``(key, value)`` pairs set on the SparkConf.
        """
        conf = SparkConf()
        for key, value in extra_confs:
            conf.set(key, value)
        conf.setMaster("local-cluster[2,1,1024]")
        self.sc = SparkContext(conf=conf)
        self._test_encryption_helper([5])
        self._test_encryption_helper([5, 10, 20])

    def test_broadcast_with_encryption(self):
        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))

    def test_broadcast_no_encryption(self):
        self._test_multiple_broadcasts()


class BroadcastFrameProtocolTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        gateway = launch_gateway(SparkConf())
        cls._jvm = gateway.jvm
        cls.longMessage = True
        random.seed(42)

    def _test_chunked_stream(self, data, py_buf_size):
        """
        Write ``data`` with Python's ChunkedStream, decode it with the JVM's
        DechunkedInputStream, and check that the decoded bytes equal ``data``.

        :param data: bytearray payload.
        :param py_buf_size: buffer size passed to ChunkedStream.
        """
        # write data using the chunked protocol from python.
        chunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
        dechunked_file.close()
        try:
            out = ChunkedStream(chunked_file, py_buf_size)
            out.write(data)
            out.close()
            # now try to read it in java
            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
            # java should have decoded it back to the original data
            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
            with open(dechunked_file.name, "rb") as f:
                decoded = bytearray(f.read())
            # Read once instead of byte-by-byte, but still report the first
            # differing index so a failure pinpoints where decoding diverged.
            # Lengths are already equal, so a mismatch must lie within zip().
            if decoded != data:
                idx = next(i for i, (expected, actual) in enumerate(zip(data, decoded))
                           if expected != actual)
                self.assertEqual(data[idx], decoded[idx], msg="idx = " + str(idx))
        finally:
            # Closing again is harmless; it covers failures before out.close().
            chunked_file.close()
            os.unlink(chunked_file.name)
            os.unlink(dechunked_file.name)

    def test_chunked_stream(self):
        def random_bytes(n):
            return bytearray(random.getrandbits(8) for _ in range(n))
        for data_length in [1, 10, 100, 10000]:
            for buffer_length in [1, 2, 5, 8192]:
                self._test_chunked_stream(random_bytes(data_length), buffer_length)


if __name__ == '__main__':
    from pyspark.tests.test_broadcast import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_conf.py
# ----------------------------------------------------------------------
# diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py
# new file mode 100644
# index 0000000..f5a9acc
# --- /dev/null
# +++ b/python/pyspark/tests/test_conf.py
# @@ -0,0 +1,43 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
import unittest

from pyspark import SparkContext, SparkConf


class ConfTests(unittest.TestCase):
    def test_memory_conf(self):
        """Check sortBy() results with spark.python.worker.memory set to several sizes."""
        memoryList = ["1T", "1G", "1M", "1024K"]
        for memory in memoryList:
            sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory))
            try:
                values = list(range(1024))
                random.shuffle(values)
                rdd = sc.parallelize(values, 4)
                self.assertEqual(sorted(values), rdd.sortBy(lambda x: x).collect())
            finally:
                # Stop even when the assertion fails, so no active context is
                # left behind for the next iteration or for later tests.
                sc.stop()


if __name__ == "__main__":
    from pyspark.tests.test_conf import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_context.py
# ----------------------------------------------------------------------
# diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
# new file mode 100644
# index 0000000..201baf4
# --- /dev/null
# +++ b/python/pyspark/tests/test_context.py
# @@ -0,0 +1,258 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tempfile
import threading
import time
import unittest

from pyspark import SparkFiles, SparkContext
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME


class CheckpointTests(ReusedPySparkTestCase):

    def setUp(self):
        # Reserve a unique path that does not exist yet: create a named temp
        # file, close its handle (so the descriptor is not leaked) and delete
        # it, keeping only the name. That name becomes the checkpoint
        # directory, and tearDown() removes it again.
        self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
        self.checkpointDir.close()
        os.unlink(self.checkpointDir.name)
        self.sc.setCheckpointDir(self.checkpointDir.name)

    def tearDown(self):
        shutil.rmtree(self.checkpointDir.name)

    def test_basic_checkpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertIsNone(flatMappedRDD.getCheckpointFile())

        flatMappedRDD.checkpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)
        self.assertEqual("file:" + self.checkpointDir.name,
                         os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))

    def test_checkpoint_and_restore(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: [x])

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertIsNone(flatMappedRDD.getCheckpointFile())

        flatMappedRDD.checkpoint()
        flatMappedRDD.count()  # forces a checkpoint to be computed
        time.sleep(1)  # 1 second

        self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
                                            flatMappedRDD._jrdd_deserializer)
        self.assertEqual([1, 2, 3, 4], recovered.collect())


class LocalCheckpointTests(ReusedPySparkTestCase):

    def test_basic_localcheckpointing(self):
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertFalse(flatMappedRDD.isLocallyCheckpointed())

        flatMappedRDD.localCheckpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)


class AddFileTests(PySparkTestCase):

    def test_add_py_file(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this job fails due to `userlibrary` not being on the Python path:
        # disable logging in log4j temporarily
        def func(x):
            from userlibrary import UserClass
            return UserClass().hello()
        with QuietTest(self.sc):
            self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)

        # Add the file, so the job should now succeed:
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        res = self.sc.parallelize(range(2)).map(func).first()
        self.assertEqual("Hello World!", res)

    def test_add_file_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        self.sc.addFile(path)
        download_path = SparkFiles.get("hello.txt")
        self.assertNotEqual(path, download_path)
        with open(download_path) as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())

    def test_add_file_recursively_locally(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello")
        self.sc.addFile(path, True)
        download_path = SparkFiles.get("hello")
        self.assertNotEqual(path, download_path)
        with open(download_path + "/hello.txt") as test_file:
            self.assertEqual("Hello World!\n", test_file.readline())
        with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
            self.assertEqual("Sub Hello World!\n", test_file.readline())

    def test_add_py_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this fails due to `userlibrary` not being on the Python path:
        def func():
            from userlibrary import UserClass
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(path)
        from userlibrary import UserClass
        self.assertEqual("Hello World!", UserClass().hello())

    def test_add_egg_file_locally(self):
        # To ensure that we're actually testing addPyFile's effects, check that
        # this fails due to `userlib` not being on the Python path:
        def func():
            from userlib import UserClass
        self.assertRaises(ImportError, func)
        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
        self.sc.addPyFile(path)
        from userlib import UserClass
        self.assertEqual("Hello World from inside a package!", UserClass().hello())

    def test_overwrite_system_module(self):
        self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))

        import SimpleHTTPServer
        self.assertEqual("My Server", SimpleHTTPServer.__name__)

        def func(x):
            import SimpleHTTPServer
            return SimpleHTTPServer.__name__

        self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())


class ContextTests(unittest.TestCase):

    def test_failed_sparkcontext_creation(self):
        # Regression test for SPARK-1550
        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))

    def test_get_or_create(self):
        with SparkContext.getOrCreate() as sc:
            self.assertTrue(SparkContext.getOrCreate() is sc)

    def test_parallelize_eager_cleanup(self):
        # parallelize() must not leave new files in the context's temp dir.
        with SparkContext() as sc:
            temp_files = os.listdir(sc._temp_dir)
            # Keep a reference so the RDD stays alive while the dir is listed.
            rdd = sc.parallelize([0, 1, 2])
            post_parallelize_temp_files = os.listdir(sc._temp_dir)
            self.assertEqual(temp_files, post_parallelize_temp_files)

    def test_set_conf(self):
        # This is for an internal use case. When there is an existing SparkContext,
        # SparkSession's builder needs to set configs into SparkContext's conf.
        sc = SparkContext()
        try:
            sc._conf.set("spark.test.SPARK16224", "SPARK16224")
            self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
        finally:
            # Stop even on failure so later tests do not inherit this context.
            sc.stop()

    def test_stop(self):
        sc = SparkContext()
        try:
            self.assertNotEqual(SparkContext._active_spark_context, None)
        finally:
            sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_exception(self):
        try:
            with SparkContext() as sc:
                self.assertNotEqual(SparkContext._active_spark_context, None)
                raise Exception()
        except Exception:
            # Raised on purpose; only the context's cleanup is under test.
            pass
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_with_stop(self):
        with SparkContext() as sc:
            self.assertNotEqual(SparkContext._active_spark_context, None)
            sc.stop()
        self.assertEqual(SparkContext._active_spark_context, None)

    def test_progress_api(self):
        with SparkContext() as sc:
            sc.setJobGroup('test_progress_api', '', True)
            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))

            def run():
                try:
                    rdd.count()
                except Exception:
                    # The job is cancelled below, so count() is expected to fail.
                    pass
            t = threading.Thread(target=run)
            t.daemon = True
            t.start()
            # wait for scheduler to start
            time.sleep(1)

            tracker = sc.statusTracker()
            jobIds = tracker.getJobIdsForGroup('test_progress_api')
            self.assertEqual(1, len(jobIds))
            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual(1, len(job.stageIds))
            stage = tracker.getStageInfo(job.stageIds[0])
            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)

            sc.cancelAllJobs()
            t.join()
            # wait for event listener to update the status
            time.sleep(1)

            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual('FAILED', job.status)
            self.assertEqual([], tracker.getActiveJobsIds())
            self.assertEqual([], tracker.getActiveStageIds())

            sc.stop()

    def test_startTime(self):
        with SparkContext() as sc:
            self.assertGreater(sc.startTime, 0)


if __name__ == "__main__":
    from pyspark.tests.test_context import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_daemon.py
# ----------------------------------------------------------------------
# diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py
# new file mode 100644
# index 0000000..fccd74f
# --- /dev/null
# +++ b/python/pyspark/tests/test_daemon.py
# @@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import time
import unittest

from pyspark.serializers import read_int


class DaemonTests(unittest.TestCase):
    def connect(self, port):
        """
        Connect to 127.0.0.1:``port`` and send a split index of -1, which asks
        the worker to shut down. Returns True on success; socket errors
        (e.g. a refused connection) propagate to the caller.
        """
        from socket import socket, AF_INET, SOCK_STREAM
        sock = socket(AF_INET, SOCK_STREAM)
        try:
            sock.connect(('127.0.0.1', port))
            # send a split index of -1 to shutdown the worker
            sock.send(b"\xFF\xFF\xFF\xFF")
        finally:
            # Close even when connect() fails. do_termination_test expects the
            # second connect() to be refused and would otherwise leak a socket.
            sock.close()
        return True

    def do_termination_test(self, terminator):
        """
        Start the ``daemon.py`` that sits one directory above this file, check
        that it accepts a connection, call ``terminator(daemon)``, and check
        that connections are refused afterwards.

        :param terminator: callable that receives the daemon's ``Popen`` object.
        """
        from subprocess import Popen, PIPE
        from errno import ECONNREFUSED

        # start daemon
        daemon_path = os.path.join(os.path.dirname(__file__), "..", "daemon.py")
        python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
        daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
        try:
            # read the port number
            port = read_int(daemon.stdout)

            # daemon should accept connections
            self.assertTrue(self.connect(port))

            # request shutdown
            terminator(daemon)
            time.sleep(1)

            # daemon should no longer accept connections
            try:
                self.connect(port)
            except EnvironmentError as exception:
                self.assertEqual(exception.errno, ECONNREFUSED)
            else:
                self.fail("Expected EnvironmentError to be raised")
        finally:
            # Always reap the child, killing it first if it is still alive,
            # and close its pipes so a failed assertion leaves no stray process.
            if daemon.poll() is None:
                daemon.kill()
            daemon.wait()
            daemon.stdin.close()
            daemon.stdout.close()

    def test_termination_stdin(self):
        """Ensure that daemon and workers terminate when stdin is closed."""
        self.do_termination_test(lambda daemon: daemon.stdin.close())

    def test_termination_sigterm(self):
        """Ensure that daemon and workers terminate on SIGTERM."""
        from signal import SIGTERM
        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))


if __name__ == "__main__":
    from pyspark.tests.test_daemon import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_join.py
# ----------------------------------------------------------------------
# diff --git
# a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py
# new file mode 100644
# index 0000000..e97e695
# --- /dev/null
# +++ b/python/pyspark/tests/test_join.py
# @@ -0,0 +1,69 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.testing.utils import ReusedPySparkTestCase


class JoinTests(ReusedPySparkTestCase):

    def test_narrow_dependency_in_join(self):
        """
        Check partition counts of unions, and the number of stages each
        join/cogroup job runs, with co-partitioned and un-partitioned inputs.
        """
        pairs = self.sc.parallelize(range(10)).map(lambda x: (x, x))
        partitioned = pairs.partitionBy(2)

        # A union of the 2-partition RDD with itself stays at 2 partitions;
        # mixing in the un-partitioned RDD adds both partition counts.
        self.assertEqual(2, partitioned.union(partitioned).getNumPartitions())
        self.assertEqual(pairs.getNumPartitions() + 2,
                         partitioned.union(pairs).getNumPartitions())
        self.assertEqual(pairs.getNumPartitions() + 2,
                         pairs.union(partitioned).getNumPartitions())

        tracker = self.sc.statusTracker()

        def head_row(rows):
            return rows[0]

        def head_groups(rows):
            return list(map(list, rows[0][1]))

        # Each entry: job group, builder for the RDD to collect (built only
        # after the group is set), projection of the sorted rows, expected
        # projection, and expected stage count. Joining the two co-partitioned
        # sides is expected to need fewer stages than joining with `pairs`.
        scenarios = [
            ("test1", lambda: partitioned.join(partitioned), head_row, (0, (0, 0)), 2),
            ("test2", lambda: partitioned.join(pairs), head_row, (0, (0, 0)), 3),
            ("test3", lambda: partitioned.cogroup(partitioned), head_groups, [[0], [0]], 2),
            ("test4", lambda: partitioned.cogroup(pairs), head_groups, [[0], [0]], 3),
        ]

        for group, build, project, expected_head, expected_stages in scenarios:
            self.sc.setJobGroup(group, "test", True)
            rows = sorted(build().collect())
            self.assertEqual(10, len(rows))
            self.assertEqual(expected_head, project(rows))
            job_id = tracker.getJobIdsForGroup(group)[0]
            self.assertEqual(expected_stages, len(tracker.getJobInfo(job_id).stageIds))


if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_join import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
# http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_profiler.py
# ----------------------------------------------------------------------
# diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py
# new file mode 100644
# index 0000000..56cbcff
# --- /dev/null
# +++ b/python/pyspark/tests/test_profiler.py
# @@ -0,0 +1,112 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import tempfile +import unittest + +from pyspark import SparkConf, SparkContext, BasicProfiler +from pyspark.testing.utils import PySparkTestCase + +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO + + +class ProfilerTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, conf=conf) + + def test_profiler(self): + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + old_stdout = sys.stdout + sys.stdout = io = StringIO() + self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 18): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + + +class ProfilerTests2(unittest.TestCase): + def 
test_profiler_disabled(self): + sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) + try: + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.show_profiles()) + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.dump_profiles("/tmp/abc")) + finally: + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_profiler import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) http://git-wip-us.apache.org/repos/asf/spark/blob/03306a6d/python/pyspark/tests/test_rdd.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py new file mode 100644 index 0000000..b2a544b --- /dev/null +++ b/python/pyspark/tests/test_rdd.py @@ -0,0 +1,739 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import hashlib +import os +import random +import sys +import tempfile +from glob import glob + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, RDD +from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\ + MarshalSerializer, UTF8Deserializer, NoOpSerializer +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class RDDTests(ReusedPySparkTestCase): + + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + + def test_id(self): + rdd = self.sc.parallelize(range(10)) + id = rdd.id() + self.assertEqual(id, rdd.id()) + rdd2 = rdd.map(str).filter(bool) + id2 = rdd2.id() + self.assertEqual(id + 1, id2) + self.assertEqual(id2, rdd2.id()) + + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + + def test_to_localiterator(self): + from time import sleep + rdd = self.sc.parallelize([1, 2, 3]) + it = rdd.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it)) + + rdd2 = rdd.repartition(1000) + it2 = rdd2.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it2)) + + def test_save_as_textfile_with_unicode(self): + # Regression test for SPARK-970 + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) + + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" 
+ data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) + + def test_transforming_cartesian_result(self): + # Regression test for SPARK-1034 + rdd1 = self.sc.parallelize([1, 2]) + rdd2 = self.sc.parallelize([3, 4]) + cart = rdd1.cartesian(rdd2) + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() + + def test_transforming_pickle_file(self): + # Regression test for SPARK-2601 + data = self.sc.parallelize([u"Hello", u"World!"]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsPickleFile(tempFile.name) + pickled_file = self.sc.pickleFile(tempFile.name) + pickled_file.map(lambda x: x).collect() + + def test_cartesian_on_textfile(self): + # Regression test for + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + a = self.sc.textFile(path) + result = a.cartesian(a).collect() + (x, y) = result[0] + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) + + def test_cartesian_chaining(self): + # Tests for SPARK-16589 + rdd = self.sc.parallelize(range(10), 2) + self.assertSetEqual( + set(rdd.cartesian(rdd).cartesian(rdd).collect()), + set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.cartesian(rdd)).collect()), + set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.zip(rdd)).collect()), + set([(x, (y, y)) for x in range(10) for y in range(10)]) + ) + + def test_zip_chaining(self): + # Tests for SPARK-21985 + rdd = self.sc.parallelize('abc', 2) + self.assertSetEqual( + set(rdd.zip(rdd).zip(rdd).collect()), + set([((x, x), x) for x in 'abc']) + ) + 
self.assertSetEqual( + set(rdd.zip(rdd.zip(rdd)).collect()), + set([(x, (x, x)) for x in 'abc']) + ) + + def test_deleting_input_files(self): + # Regression test for SPARK-1025 + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + def test_sampling_default_seed(self): + # Test for SPARK-3995 (default seed setting) + data = self.sc.parallelize(xrange(1000), 1) + subset = data.takeSample(False, 10) + self.assertEqual(len(subset), 10) + + def test_aggregate_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregate and treeAggregate to build dict + # representing a counter of ints + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + # Show that single or multiple partitions work + data1 = self.sc.range(10, numSlices=1) + data2 = self.sc.range(10, numSlices=2) + + def seqOp(x, y): + x[y] += 1 + return x + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) + counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) + counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + + ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) + self.assertEqual(counts1, ground_truth) + self.assertEqual(counts2, ground_truth) + self.assertEqual(counts3, ground_truth) + self.assertEqual(counts4, ground_truth) + + def test_aggregate_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that + # contains lists of all values for each key in the original RDD + + # list(range(...)) for 
Python 3.x compatibility (can't use * operator + # on a range object) + # list(zip(...)) for Python 3.x compatibility (want to parallelize a + # collection, not a zip object) + tuples = list(zip(list(range(10))*2, [1]*20)) + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def seqOp(x, y): + x.append(y) + return x + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.aggregateByKey([], seqOp, comboOp).collect() + values2 = data2.aggregateByKey([], seqOp, comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + ground_truth = [(i, [1]*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_fold_mutable_zero_value(self): + # Test for SPARK-9021; uses fold to merge an RDD of dict counters into + # a single dict + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + counts1 = defaultdict(int, dict((i, 1) for i in range(10))) + counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) + counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) + counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) + all_counts = [counts1, counts2, counts3, counts4] + # Show that single or multiple partitions work + data1 = self.sc.parallelize(all_counts, 1) + data2 = self.sc.parallelize(all_counts, 2) + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + fold1 = data1.fold(defaultdict(int), comboOp) + fold2 = data2.fold(defaultdict(int), comboOp) + + ground_truth = defaultdict(int) + for counts in all_counts: + for key, val in counts.items(): + ground_truth[key] += val + self.assertEqual(fold1, ground_truth) + self.assertEqual(fold2, ground_truth) + + def test_fold_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses foldByKey to make a pair 
RDD that contains + # lists of all values for each key in the original RDD + + tuples = [(i, range(i)) for i in range(10)]*2 + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.foldByKey([], comboOp).collect() + values2 = data2.foldByKey([], comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + # list(range(...)) for Python 3.x compatibility + ground_truth = [(i, list(range(i))*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_aggregate_by_key(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) + + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEqual([jon, jane], theDoes.collect()) + + def test_large_broadcast(self): + N = 10000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 27MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = 
self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = list(range(1 << 15)) + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + def test_multithread_broadcast_pickle(self): + import threading + + b1 = self.sc.broadcast(list(range(3))) + b2 = self.sc.broadcast(list(range(3))) + + def f1(): + return b1.value + + def f2(): + return b2.value + + funcs_num_pickled = {f1: None, f2: None} + + def do_pickle(f, sc): + command = (f, None, sc.serializer, sc.serializer) + ser = CloudPickleSerializer() + ser.dumps(command) + + def process_vars(sc): + broadcast_vars = list(sc._pickled_broadcast_vars) + num_pickled = len(broadcast_vars) + sc._pickled_broadcast_vars.clear() + return num_pickled + + def run(f, sc): + do_pickle(f, sc) + funcs_num_pickled[f] = process_vars(sc) + + # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage + do_pickle(f1, self.sc) + + # run all for f2, should only 
add/count/clear b2 from worker thread local storage + t = threading.Thread(target=run, args=(f2, self.sc)) + t.start() + t.join() + + # count number of vars pickled in main thread, only b1 should be counted and cleared + funcs_num_pickled[f1] = process_vars(self.sc) + + self.assertEqual(funcs_num_pickled[f1], 1) + self.assertEqual(funcs_num_pickled[f2], 1) + self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) + + def test_large_closure(self): + N = 200000 + data = [float(i) for i in xrange(N)] + rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) + + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + # regression test for SPARK-4841 + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + t = self.sc.textFile(path) + cnt = t.count() + self.assertEqual(cnt, t.zip(t).count()) + rdd = t.map(str) + self.assertEqual(cnt, t.zip(rdd).count()) + # regression test for bug in _reserializer() + self.assertEqual(cnt, t.zip(rdd).count()) + + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + with QuietTest(self.sc): 
+ # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + + def test_count_approx_distinct(self): + rdd = self.sc.parallelize(xrange(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) + + rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) + self.assertTrue(18 < rdd.countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) + + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) + + def test_histogram(self): + # empty + rdd = self.sc.parallelize([]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertRaises(ValueError, lambda: rdd.histogram(1)) + + # out of range + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) + + # in range with one bucket + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) + + # in range with one bucket exact match + self.assertEqual([4], 
rdd.histogram([1, 4])[1]) + + # out of range with two buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) + + # out of range with two uneven buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + + # in range with two buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two bucket and None + rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two uneven buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) + + # mixed range with two uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) + + # mixed range with four uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # mixed range with uneven buckets and NaN + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, + 199.0, 200.0, 200.1, None, float('nan')]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # out of range with infinite buckets + rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + + # invalid buckets + self.assertRaises(ValueError, lambda: rdd.histogram([])) + self.assertRaises(ValueError, lambda: rdd.histogram([1])) + self.assertRaises(ValueError, lambda: rdd.histogram(0)) + self.assertRaises(TypeError, lambda: rdd.histogram({})) + + # without buckets + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) + + # without buckets single element + rdd = 
self.sc.parallelize([1]) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) + + # without bucket no range + rdd = self.sc.parallelize([1] * 4) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) + + # without buckets basic two + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + + # without buckets with more requested than elements + rdd = self.sc.parallelize([1, 2]) + buckets = [1 + 0.2 * i for i in range(6)] + hist = [1, 0, 0, 0, 1] + self.assertEqual((buckets, hist), rdd.histogram(5)) + + # invalid RDDs + rdd = self.sc.parallelize([1, float('inf')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + rdd = self.sc.parallelize([float('nan')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + + # string + rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) + self.assertRaises(TypeError, lambda: rdd.histogram(2)) + + def test_repartitionAndSortWithinPartitions_asc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + + def test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) + + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = 
len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + + def test_distinct(self): + rdd = self.sc.parallelize((1, 2, 3)*10, 10) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) + result = rdd.distinct(5) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) + + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "1m") + N = 200001 + kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda kv: kv[0] == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) + + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 
11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) + + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + + def test_pipe_functions(self): + data = ['1', '2', '3'] + rdd = self.sc.parallelize(data) + with QuietTest(self.sc): + self.assertEqual([], rdd.pipe('cc').collect()) + self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) 
+ result = rdd.pipe('cat').collect() + result.sort() + for x, y in zip(data, result): + self.assertEqual(x, y) + self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) + self.assertEqual([], rdd.pipe('grep 4').collect()) + + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_rdd import * + + try: + import xmlrunner + 
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org