This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 02c645607f43 [SPARK-48512][PYTHON][TESTS] Refactor Python tests
02c645607f43 is described below

commit 02c645607f4353df573cdba568e092c3ff4c359a
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Jun 4 17:50:29 2024 +0800

    [SPARK-48512][PYTHON][TESTS] Refactor Python tests

    ### What changes were proposed in this pull request?

    Use withSQLConf in tests when it is appropriate.

    ### Why are the changes needed?

    Enforce good practice for setting config in test cases.

    ### Does this PR introduce _any_ user-facing change?

    NO

    ### How was this patch tested?

    existing UT

    ### Was this patch authored or co-authored using generative AI tooling?

    NO

    Closes #46852 from amaliujia/refactor_pyspark.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_context.py    | 39 +++++++++++++----------------
 python/pyspark/sql/tests/test_readwriter.py | 10 ++------
 python/pyspark/sql/tests/test_types.py      |  5 +---
 3 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py
index b38183331486..f363b8748c0b 100644
--- a/python/pyspark/sql/tests/test_context.py
+++ b/python/pyspark/sql/tests/test_context.py
@@ -26,13 +26,13 @@ import py4j
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.types import StructType, StringType, StructField
-from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase


-class HiveContextSQLTests(ReusedPySparkTestCase):
+class HiveContextSQLTests(ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
+        ReusedSQLTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         cls.hive_available = True
         cls.spark = None
@@ -58,7 +58,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):

     @classmethod
     def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
+        ReusedSQLTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
         if cls.spark is not None:
             cls.spark.stop()
@@ -100,23 +100,20 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         self.spark.sql("DROP TABLE savedJsonTable")
         self.spark.sql("DROP TABLE externalJsonTable")

-        defaultDataSourceName = self.spark.conf.get(
-            "spark.sql.sources.default", "org.apache.spark.sql.parquet"
-        )
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
-        actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
-        self.assertEqual(
-            sorted(df.collect()), sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())
-        )
-        self.assertEqual(
-            sorted(df.collect()),
-            sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()),
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("DROP TABLE savedJsonTable")
-        self.spark.sql("DROP TABLE externalJsonTable")
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
+            df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+            actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
+            self.assertEqual(
+                sorted(df.collect()),
+                sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()),
+            )
+            self.assertEqual(
+                sorted(df.collect()),
+                sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()),
+            )
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            self.spark.sql("DROP TABLE savedJsonTable")
+            self.spark.sql("DROP TABLE externalJsonTable")

         shutil.rmtree(tmpPath)

diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index e752856d0316..8060a9ae8bc7 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -55,12 +55,9 @@ class ReadwriterTestsMixin:
             )
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

-            try:
-                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-            finally:
-                self.spark.sql("RESET spark.sql.sources.default")

             csvpath = os.path.join(tempfile.mkdtemp(), "data")
             df.write.option("quote", None).format("csv").save(csvpath)
@@ -94,12 +91,9 @@ class ReadwriterTestsMixin:
             )
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

-            try:
-                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-            finally:
-                self.spark.sql("RESET spark.sql.sources.default")
         finally:
             shutil.rmtree(tmpPath)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 80f2c0fcbc03..1882c1fd1f6a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -491,14 +491,11 @@ class TypesTestsMixin:
         self.assertEqual(asdict(user), r.asDict())

     def test_negative_decimal(self):
-        try:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
+        with self.sql_conf({"spark.sql.legacy.allowNegativeScaleOfDecimal": True}):
             df = self.spark.createDataFrame([(1,), (11,)], ["value"])
             ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect()
             actual = list(map(lambda r: int(r.value), ret))
             self.assertEqual(actual, [0, 10])
-        finally:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")

     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
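
For readers skimming the diff: the refactor replaces hand-written SET/RESET statements with the sql_conf context manager that these test classes get via pyspark.testing.sqlutils (the diff imports ReusedSQLTestCase from there). Below is a minimal sketch of the pattern, not part of this commit; the test class and test name are illustrative only.

    from pyspark.testing.sqlutils import ReusedSQLTestCase


    class SqlConfExampleTests(ReusedSQLTestCase):
        def test_json_default_source(self):
            # sql_conf applies the given SQL confs for the duration of the block
            # and restores (or unsets) the previous values on exit, even if an
            # assertion fails, so one test cannot leak config into the next.
            with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
                self.assertEqual(
                    self.spark.conf.get("spark.sql.sources.default"),
                    "org.apache.spark.sql.json",
                )
            # Outside the block the default data source is back to its prior value.

Because the restore happens in the context manager's cleanup path, the try/finally plumbing the old tests carried is no longer needed.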