This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new e1b7a26d2f4 [SPARK-44750][PYTHON][CONNECT] Apply configuration to 
sparksession during creation
e1b7a26d2f4 is described below

commit e1b7a26d2f48f9f149498a9204db6944c7d5bca3
Author: Michael Zhang <m.zh...@databricks.com>
AuthorDate: Thu Aug 24 08:36:53 2023 +0800

    [SPARK-44750][PYTHON][CONNECT] Apply configuration to sparksession during 
creation
    
    ### What changes were proposed in this pull request?
    
    `SparkSession.Builder` now applies configuration options to the created 
`SparkSession`.
    
    ### Why are the changes needed?
    
    It is reasonable to expect PySpark connect `SparkSession.Builder` to behave 
in the same way as other `SparkSession.Builder`s in Spark Connect. The 
`SparkSession.Builder` should apply the provided configuration options to the 
created `SparkSession`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Tests were added to verify that configuration options were applied to the 
`SparkSession`.
    
    Closes #42548 from michaelzhan-db/SPARK-44750.
    
    Lead-authored-by: Michael Zhang <m.zh...@databricks.com>
    Co-authored-by: Ruifeng Zheng <ruife...@foxmail.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit c2e3171f3d3887302227edc39ee124bd61561b7d)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/session.py               | 10 ++++++++++
 .../pyspark/sql/tests/connect/test_connect_basic.py | 21 +++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index d75a30c561f..2905f7e4269 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -176,6 +176,14 @@ class SparkSession:
                 error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
"enableHiveSupport"}
             )
 
+        def _apply_options(self, session: "SparkSession") -> None:
+            with self._lock:
+                for k, v in self._options.items():
+                    try:
+                        session.conf.set(k, v)
+                    except Exception as e:
+                        warnings.warn(str(e))
+
         def create(self) -> "SparkSession":
             has_channel_builder = self._channel_builder is not None
             has_spark_remote = "spark.remote" in self._options
@@ -200,6 +208,7 @@ class SparkSession:
                 session = SparkSession(connection=spark_remote)
 
             SparkSession._set_default_and_active_session(session)
+            self._apply_options(session)
             return session
 
         def getOrCreate(self) -> "SparkSession":
@@ -209,6 +218,7 @@ class SparkSession:
                     session = SparkSession._default_session
                     if session is None:
                         session = self.create()
+                self._apply_options(session)
                 return session
 
     _client: SparkConnectClient
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9e8f5623971..54911c09b6f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3347,6 +3347,27 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
             self.assertIn("Create a new SparkSession is only supported with 
SparkConnect.", str(e))
 
 
+class SparkConnectSessionWithOptionsTest(ReusedConnectTestCase):
+    def setUp(self) -> None:
+        self.spark = (
+            PySparkSession.builder.config("string", "foo")
+            .config("integer", 1)
+            .config("boolean", False)
+            .appName(self.__class__.__name__)
+            .remote("local[4]")
+            .getOrCreate()
+        )
+
+    def tearDown(self):
+        self.spark.stop()
+
+    def test_config(self):
+        # Config
+        self.assertEqual(self.spark.conf.get("string"), "foo")
+        self.assertEqual(self.spark.conf.get("boolean"), "false")
+        self.assertEqual(self.spark.conf.get("integer"), "1")
+
+
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class ClientTests(unittest.TestCase):
     def test_retry_error_handling(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to