This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 02c645607f43 [SPARK-48512][PYTHON][TESTS] Refactor Python tests
02c645607f43 is described below

commit 02c645607f4353df573cdba568e092c3ff4c359a
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Jun 4 17:50:29 2024 +0800

    [SPARK-48512][PYTHON][TESTS] Refactor Python tests
    
    ### What changes were proposed in this pull request?
    
    Use withSQLConf in tests when it is appropriate.
    
    ### Why are the changes needed?
    
    Enforce good practice for setting config in test cases.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    existing UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    NO
    
    Closes #46852 from amaliujia/refactor_pyspark.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_context.py    | 39 +++++++++++++----------------
 python/pyspark/sql/tests/test_readwriter.py | 10 ++------
 python/pyspark/sql/tests/test_types.py      |  5 +---
 3 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py
index b38183331486..f363b8748c0b 100644
--- a/python/pyspark/sql/tests/test_context.py
+++ b/python/pyspark/sql/tests/test_context.py
@@ -26,13 +26,13 @@ import py4j
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.types import StructType, StringType, StructField
-from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
-class HiveContextSQLTests(ReusedPySparkTestCase):
+class HiveContextSQLTests(ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
+        ReusedSQLTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         cls.hive_available = True
         cls.spark = None
@@ -58,7 +58,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
 
     @classmethod
     def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
+        ReusedSQLTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
         if cls.spark is not None:
             cls.spark.stop()
@@ -100,23 +100,20 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         self.spark.sql("DROP TABLE savedJsonTable")
         self.spark.sql("DROP TABLE externalJsonTable")
 
-        defaultDataSourceName = self.spark.conf.get(
-            "spark.sql.sources.default", "org.apache.spark.sql.parquet"
-        )
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
-        actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
-        self.assertEqual(
-            sorted(df.collect()), sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())
-        )
-        self.assertEqual(
-            sorted(df.collect()),
-            sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()),
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("DROP TABLE savedJsonTable")
-        self.spark.sql("DROP TABLE externalJsonTable")
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
+            df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+            actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
+            self.assertEqual(
+                sorted(df.collect()),
+                sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()),
+            )
+            self.assertEqual(
+                sorted(df.collect()),
+                sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()),
+            )
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            self.spark.sql("DROP TABLE savedJsonTable")
+            self.spark.sql("DROP TABLE externalJsonTable")
 
         shutil.rmtree(tmpPath)
 
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index e752856d0316..8060a9ae8bc7 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -55,12 +55,9 @@ class ReadwriterTestsMixin:
             )
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-            try:
                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-            finally:
-                self.spark.sql("RESET spark.sql.sources.default")
 
             csvpath = os.path.join(tempfile.mkdtemp(), "data")
             df.write.option("quote", None).format("csv").save(csvpath)
@@ -94,12 +91,9 @@ class ReadwriterTestsMixin:
             )
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-            try:
                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-            finally:
-                self.spark.sql("RESET spark.sql.sources.default")
         finally:
             shutil.rmtree(tmpPath)
 
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 80f2c0fcbc03..1882c1fd1f6a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -491,14 +491,11 @@ class TypesTestsMixin:
         self.assertEqual(asdict(user), r.asDict())
 
     def test_negative_decimal(self):
-        try:
            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
+        with self.sql_conf({"spark.sql.legacy.allowNegativeScaleOfDecimal": True}):
             df = self.spark.createDataFrame([(1,), (11,)], ["value"])
             ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect()
             actual = list(map(lambda r: int(r.value), ret))
             self.assertEqual(actual, [0, 10])
-        finally:
            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")
 
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to