Repository: spark
Updated Branches:
  refs/heads/branch-2.3 2f82c037d -> c854b6ca7


[SPARK-23691][PYTHON][BRANCH-2.3] Use sql_conf util in PySpark tests where 
possible

## What changes were proposed in this pull request?

This PR backports https://github.com/apache/spark/pull/20830 to reduce the diff 
against master and restore the default configuration values in PySpark tests.

https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e 
added a useful util. This backport extracts and brings in that util:

```python
@contextmanager
def sql_conf(self, pairs):
    ...
```

to allow configuration set/unset within a block:

```python
with self.sql_conf({"spark.blah.blah.blah", "blah"})
    # test codes
```
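
For reference, the full util as added in the diff below (with `contextmanager` imported from the standard library's `contextlib`):

```python
@contextmanager
def sql_conf(self, pairs):
    """
    A convenient context manager to test some configuration specific logic. This sets
    `value` to the configuration `key` and then restores it back when it exits.
    """
    assert isinstance(pairs, dict), "pairs should be a dictionary."

    keys = pairs.keys()
    new_values = pairs.values()
    # Remember the previous values (None when unset) so they can be restored on exit.
    old_values = [self.spark.conf.get(key, None) for key in keys]
    for key, new_value in zip(keys, new_values):
        self.spark.conf.set(key, new_value)
    try:
        yield
    finally:
        # Restore each key to its previous value, or unset it if it was not set before.
        for key, old_value in zip(keys, old_values):
            if old_value is None:
                self.spark.conf.unset(key)
            else:
                self.spark.conf.set(key, old_value)
```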

This PR proposes to use this util where possible in PySpark tests.

Note that there already appear to be a few places where tests affect each other 
because the original configuration value is not restored in the unittest classes.
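
For a condensed before/after picture, the cross-join test in the diff below changes from manual set/unset to the context manager:

```python
# Before: manual set plus try/finally; forgetting the finally block (as in
# some tests) leaks the setting into other tests.
try:
    self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
    self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
finally:
    # We should unset this. Otherwise, other tests are affected.
    self.spark.conf.unset("spark.sql.crossJoin.enabled")

# After: the context manager restores or unsets the value automatically.
with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
    self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
```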

## How was this patch tested?

Likewise, manually tested via:

```
./run-tests --modules=pyspark-sql --python-executables=python2
./run-tests --modules=pyspark-sql --python-executables=python3
```

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #20863 from HyukjinKwon/backport-20830.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c854b6ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c854b6ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c854b6ca

Branch: refs/heads/branch-2.3
Commit: c854b6ca7ba4dc33138c12ba4606ff8fbe82aef2
Parents: 2f82c03
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Tue Mar 20 17:53:09 2018 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Tue Mar 20 17:53:09 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 143 ++++++++++++++++++++-------------------
 1 file changed, 72 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c854b6ca/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 82c5500..d806e5d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -33,6 +33,7 @@ import datetime
 import array
 import ctypes
 import py4j
+from contextlib import contextmanager
 
 try:
     import xmlrunner
@@ -201,6 +202,28 @@ class ReusedSQLTestCase(ReusedPySparkTestCase):
                "\n\nResult:\n%s\n%s" % (result, result.dtypes))
         self.assertTrue(expected.equals(result), msg=msg)
 
+    @contextmanager
+    def sql_conf(self, pairs):
+        """
+        A convenient context manager to test some configuration specific logic. This sets
+        `value` to the configuration `key` and then restores it back when it exits.
+        """
+        assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+        keys = pairs.keys()
+        new_values = pairs.values()
+        old_values = [self.spark.conf.get(key, None) for key in keys]
+        for key, new_value in zip(keys, new_values):
+            self.spark.conf.set(key, new_value)
+        try:
+            yield
+        finally:
+            for key, old_value in zip(keys, old_values):
+                if old_value is None:
+                    self.spark.conf.unset(key)
+                else:
+                    self.spark.conf.set(key, old_value)
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
@@ -2409,17 +2432,13 @@ class SQLTests(ReusedSQLTestCase):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")
 
-        try:
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
             self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
 
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
             actual = df1.join(df2, how="inner").collect()
             expected = [Row(a=0, b=0)]
             self.assertEqual(actual, expected)
-        finally:
-            # We should unset this. Otherwise, other tests are affected.
-            self.spark.conf.unset("spark.sql.crossJoin.enabled")
 
     # Regression test for invalid join methods when on is None, Spark-14761
     def test_invalid_join_method(self):
@@ -2891,21 +2910,18 @@ class SQLTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df.toPandas())
 
         orig_env_tz = os.environ.get('TZ', None)
-        orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
         try:
             tz = 'America/Los_Angeles'
             os.environ['TZ'] = tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', tz)
-
-            df = self.spark.createDataFrame(pdf)
-            self.assertPandasEqual(pdf, df.toPandas())
+            with self.sql_conf({'spark.sql.session.timeZone': tz}):
+                df = self.spark.createDataFrame(pdf)
+                self.assertPandasEqual(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
                 os.environ['TZ'] = orig_env_tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)
 
 
 class HiveSparkSubmitTests(SparkSubmitTests):
@@ -3472,12 +3488,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(all([c == 1 for c in null_counts]))
 
     def _toPandas_arrow_toggle(self, df):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             pdf = df.toPandas()
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         pdf_arrow = df.toPandas()
+
         return pdf, pdf_arrow
 
     def test_toPandas_arrow_toggle(self):
@@ -3489,16 +3504,17 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf_arrow_la, pdf_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
             self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
 
@@ -3511,8 +3527,6 @@ class ArrowTests(ReusedSQLTestCase):
                     pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                         pdf_la_corrected[field.name], timezone)
             self.assertPandasEqual(pdf_ny, pdf_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
@@ -3528,12 +3542,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(pdf.empty)
 
     def _createDataFrame_toggle(self, pdf, schema=None):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
         return df_no_arrow, df_arrow
 
     def test_createDataFrame_toggle(self):
@@ -3544,18 +3557,18 @@ class ArrowTests(ReusedSQLTestCase):
     def test_createDataFrame_respect_session_timezone(self):
         from datetime import timedelta
         pdf = self.create_pandas_data_frame()
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
-                result_la = df_no_arrow_la.collect()
-                result_arrow_la = df_arrow_la.collect()
-                self.assertEqual(result_la, result_arrow_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_la = df_no_arrow_la.collect()
+            result_arrow_la = df_arrow_la.collect()
+            self.assertEqual(result_la, result_arrow_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
             result_ny = df_no_arrow_ny.collect()
             result_arrow_ny = df_arrow_ny.collect()
@@ -3568,8 +3581,6 @@ class ArrowTests(ReusedSQLTestCase):
                                           for k, v in row.asDict().items()})
                                    for row in result_la]
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_createDataFrame_with_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -4222,9 +4233,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
     def test_vectorized_udf_check_config(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
-        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
-        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
             df = self.spark.range(10, numPartitions=1)
 
             @pandas_udf(returnType=LongType())
@@ -4234,11 +4243,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             result = df.select(check_records_per_batch(col("id"))).collect()
             for (r,) in result:
                 self.assertTrue(r <= 3)
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
 
     def test_vectorized_udf_timestamps_respect_session_timezone(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4257,30 +4261,27 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         internal_value = pandas_udf(
             lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
 
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
-                    .withColumn("internal_value", internal_value(col("timestamp")))
-                result_la = df_la.select(col("idx"), col("internal_value")).collect()
-                # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
-                result_la_corrected = \
-                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+            diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+            result_la_corrected = \
+                df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
 
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                 .withColumn("internal_value", internal_value(col("timestamp")))
             result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
 
             self.assertNotEqual(result_ny, result_la)
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations

