This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ce53fdfa32f [SPARK-44410][PYTHON][CONNECT] Set active session in create, not just getOrCreate ce53fdfa32f is described below commit ce53fdfa32f9fdabdf14df695aba674bb4a377a8 Author: Alice Sayutina <alice.sayut...@databricks.com> AuthorDate: Sun Jul 16 16:31:13 2023 +0900 [SPARK-44410][PYTHON][CONNECT] Set active session in create, not just getOrCreate

### What changes were proposed in this pull request?

ML and other code paths rely on `_active_spark_session` to find the Spark session. Sessions created with the `getOrCreate` method set this variable, but sessions created with `create` don't. Update the `create` method to set `_active_spark_session` as well.

### Why are the changes needed?

Without this change, Spark Connect clients such as pyspark.ml and pandas cannot find a session that was created with `create`.

### Does this PR introduce _any_ user-facing change?

Sessions created by `create` are now set as the current session. This is slightly different behavior than before; however, it suits the interests of almost all clients. The only case it might break is if someone uses a mix of both `create` and `getOrCreate` and relies on the exact previous semantics. If this is undesired, we can hide it behind a configuration flag, e.g. `create(set_active_session=False)`. In that case, clients who use `create` and want to use pyspark.ml/pandas will need to update their code to set it to True.

### How was this patch tested?

Unit tests (UT).

Closes #41987 from cdkrot/spark_session_create_store_session. 
Authored-by: Alice Sayutina <alice.sayut...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/session.py | 11 +++++++++-- python/pyspark/sql/tests/connect/test_connect_basic.py | 4 ++-- python/pyspark/sql/tests/connect/test_session.py | 7 +++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 3f9d46a22f4..52eab1bf5f9 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -93,6 +93,8 @@ if TYPE_CHECKING: # `_active_spark_session` stores the active spark connect session created by # `SparkSession.builder.getOrCreate`. It is used by ML code. +# If sessions are created with `SparkSession.builder.create`, it stores +# The last created session _active_spark_session = None @@ -172,6 +174,8 @@ class SparkSession: ) def create(self) -> "SparkSession": + global _active_spark_session + has_channel_builder = self._channel_builder is not None has_spark_remote = "spark.remote" in self._options @@ -188,11 +192,14 @@ class SparkSession: if has_channel_builder: assert self._channel_builder is not None - return SparkSession(connection=self._channel_builder) + session = SparkSession(connection=self._channel_builder) else: spark_remote = to_str(self._options.get("spark.remote")) assert spark_remote is not None - return SparkSession(connection=spark_remote) + session = SparkSession(connection=spark_remote) + + _active_spark_session = session + return session def getOrCreate(self) -> "SparkSession": global _active_spark_session diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index c1235620990..5259ea6b5f5 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3323,9 +3323,9 @@ class SparkConnectSessionTests(ReusedConnectTestCase): other = 
PySparkSession.builder.remote("sc://other.remote:114/").create() self.assertNotEquals(self.spark, other) - # Reuses an active session that was previously created. + # Gets currently active session. same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate() - self.assertEquals(self.spark, same) + self.assertEquals(other, same) same.stop() # Make sure the environment is clean. diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py index bde22d80303..2f14eeddc1e 100644 --- a/python/pyspark/sql/tests/connect/test_session.py +++ b/python/pyspark/sql/tests/connect/test_session.py @@ -63,3 +63,10 @@ class SparkSessionTestCase(unittest.TestCase): self.assertFalse(session.is_stopped) session.stop() self.assertTrue(session.is_stopped) + + def test_session_create_sets_active_session(self): + session = RemoteSparkSession.builder.remote("sc://abc").create() + session2 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertIs(session, session2) + session.stop() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org