Hyukjin Kwon created SPARK-22379:
------------------------------------

             Summary: Reduce duplication of setUpClass and tearDownClass in PySpark SQL tests
                 Key: SPARK-22379
                 URL: https://issues.apache.org/jira/browse/SPARK-22379
             Project: Spark
          Issue Type: Improvement
          Components: PySpark
    Affects Versions: 2.3.0
            Reporter: Hyukjin Kwon
            Priority: Trivial
Looks like there is some duplication in sql/tests.py:

{code}
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 98afae662b4..6812da6b309 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -179,6 +179,18 @@ class MyObject(object):
         self.value = value
 
 
+class ReusedSQLTestCase(ReusedPySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -214,21 +226,19 @@ class DataTypeTests(unittest.TestCase):
         self.assertRaises(TypeError, struct_field.typeName)
 
 
-class SQLTests(ReusedPySparkTestCase):
+class SQLTests(ReusedSQLTestCase):
 
     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
+        ReusedSQLTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(cls.tempdir.name)
-        cls.spark = SparkSession(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         cls.df = cls.spark.createDataFrame(cls.testData)
 
     @classmethod
     def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+        ReusedSQLTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
     def test_sqlcontext_reuses_sparksession(self):
@@ -2623,17 +2633,7 @@ class HiveSparkSubmitTests(SparkSubmitTests):
         self.assertTrue(os.path.exists(metastore_path))
 
 
-class SQLTests2(ReusedPySparkTestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class SQLTests2(ReusedSQLTestCase):
 
     # We can't include this test into SQLTests because we will stop class's SparkContext and cause
     # other tests failed.
@@ -3082,12 +3082,12 @@ class DataTypeVerificationTests(unittest.TestCase):
 
 
 @unittest.skipIf(not _have_arrow, "Arrow not installed")
-class ArrowTests(ReusedPySparkTestCase):
+class ArrowTests(ReusedSQLTestCase):
 
     @classmethod
     def setUpClass(cls):
         from datetime import datetime
-        ReusedPySparkTestCase.setUpClass()
+        ReusedSQLTestCase.setUpClass()
 
         # Synchronize default timezone between Python and Java
         cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
@@ -3095,7 +3095,6 @@ class ArrowTests(ReusedPySparkTestCase):
         os.environ["TZ"] = tz
         time.tzset()
 
-        cls.spark = SparkSession(cls.sc)
         cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         cls.schema = StructType([
@@ -3116,8 +3115,7 @@ class ArrowTests(ReusedPySparkTestCase):
         if cls.tz_prev is not None:
             os.environ["TZ"] = cls.tz_prev
             time.tzset()
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+        ReusedSQLTestCase.tearDownClass()
 
     def assertFramesEqual(self, df_with_arrow, df_without):
         msg = ("DataFrame from Arrow is not equal" +
@@ -3169,17 +3167,7 @@ class ArrowTests(ReusedPySparkTestCase):
 
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
-class VectorizedUDFTests(ReusedPySparkTestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class VectorizedUDFTests(ReusedSQLTestCase):
 
     def test_vectorized_udf_basic(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -3478,16 +3466,7 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
 
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
-class GroupbyApplyTests(ReusedPySparkTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class GroupbyApplyTests(ReusedSQLTestCase):
 
     def assertFramesEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
{code}

Looks like we can easily deduplicate it.
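After the refactoring, adding a SQL test suite should only require subclassing the shared base. A minimal sketch of the intended usage, assuming ReusedPySparkTestCase lives in pyspark.tests as today; the RangeTests suite and its test names are hypothetical, just for illustration:

{code}
import unittest

from pyspark.sql import Row, SparkSession
from pyspark.tests import ReusedPySparkTestCase


class ReusedSQLTestCase(ReusedPySparkTestCase):
    # Shared base from the patch above: one SparkSession per test class,
    # built on the SparkContext that ReusedPySparkTestCase provides as cls.sc.
    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        cls.spark.stop()


class RangeTests(ReusedSQLTestCase):
    # Hypothetical suite: no setUpClass/tearDownClass boilerplate needed,
    # cls.spark is inherited from ReusedSQLTestCase.

    def test_range_count(self):
        self.assertEqual(self.spark.range(10).count(), 10)

    def test_create_dataframe(self):
        df = self.spark.createDataFrame([Row(key=1, value="1")])
        self.assertEqual(df.first().value, "1")


if __name__ == "__main__":
    unittest.main()
{code}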