Hyukjin Kwon created SPARK-22379:
------------------------------------

             Summary: Reduce duplication of setUpClass and tearDownClass in PySpark SQL tests
                 Key: SPARK-22379
                 URL: https://issues.apache.org/jira/browse/SPARK-22379
             Project: Spark
          Issue Type: Improvement
          Components: PySpark
    Affects Versions: 2.3.0
            Reporter: Hyukjin Kwon
            Priority: Trivial


Looks like there is some duplication in sql/tests.py:

{code}
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 98afae662b4..6812da6b309 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -179,6 +179,18 @@ class MyObject(object):
         self.value = value


+class ReusedSQLTestCase(ReusedPySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -214,21 +226,19 @@ class DataTypeTests(unittest.TestCase):
         self.assertRaises(TypeError, struct_field.typeName)


-class SQLTests(ReusedPySparkTestCase):
+class SQLTests(ReusedSQLTestCase):

     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
+        ReusedSQLTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(cls.tempdir.name)
-        cls.spark = SparkSession(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         cls.df = cls.spark.createDataFrame(cls.testData)

     @classmethod
     def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+        ReusedSQLTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)

     def test_sqlcontext_reuses_sparksession(self):
@@ -2623,17 +2633,7 @@ class HiveSparkSubmitTests(SparkSubmitTests):
         self.assertTrue(os.path.exists(metastore_path))


-class SQLTests2(ReusedPySparkTestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class SQLTests2(ReusedSQLTestCase):

     # We can't include this test into SQLTests because we will stop class's SparkContext and cause
     # other tests failed.
@@ -3082,12 +3082,12 @@ class DataTypeVerificationTests(unittest.TestCase):


 @unittest.skipIf(not _have_arrow, "Arrow not installed")
-class ArrowTests(ReusedPySparkTestCase):
+class ArrowTests(ReusedSQLTestCase):

     @classmethod
     def setUpClass(cls):
         from datetime import datetime
-        ReusedPySparkTestCase.setUpClass()
+        ReusedSQLTestCase.setUpClass()

         # Synchronize default timezone between Python and Java
         cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
@@ -3095,7 +3095,6 @@ class ArrowTests(ReusedPySparkTestCase):
         os.environ["TZ"] = tz
         time.tzset()

-        cls.spark = SparkSession(cls.sc)
         cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         cls.schema = StructType([
@@ -3116,8 +3115,7 @@ class ArrowTests(ReusedPySparkTestCase):
         if cls.tz_prev is not None:
             os.environ["TZ"] = cls.tz_prev
         time.tzset()
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+        ReusedSQLTestCase.tearDownClass()

     def assertFramesEqual(self, df_with_arrow, df_without):
         msg = ("DataFrame from Arrow is not equal" +
@@ -3169,17 +3167,7 @@ class ArrowTests(ReusedPySparkTestCase):


 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
-class VectorizedUDFTests(ReusedPySparkTestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class VectorizedUDFTests(ReusedSQLTestCase):

     def test_vectorized_udf_basic(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -3478,16 +3466,7 @@ class VectorizedUDFTests(ReusedPySparkTestCase):


 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
-class GroupbyApplyTests(ReusedPySparkTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class GroupbyApplyTests(ReusedSQLTestCase):

     def assertFramesEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
{code}

Looks like we can easily deduplicate it.
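
For illustration, once ReusedSQLTestCase is in place, a suite that needs no extra per-class fixtures reduces to something like this (ExampleSQLTests is a hypothetical name for illustration only, not part of the patch):

{code}
class ExampleSQLTests(ReusedSQLTestCase):
    # Assumes this lives in sql/tests.py next to ReusedSQLTestCase, so no
    # imports are needed. self.spark is created once per class by
    # ReusedSQLTestCase.setUpClass() and stopped by tearDownClass(), so the
    # suite carries no session boilerplate of its own.

    def test_range_count(self):
        self.assertEqual(self.spark.range(10).count(), 10)
{code}

Suites that do need extra fixtures, like SQLTests above, keep only their own setup and delegate the session lifecycle to the base class by calling ReusedSQLTestCase.setUpClass() / tearDownClass().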



