This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 56c7cf33929 [SPARK-41933][CONNECT] Provide local mode that automatically starts the server
56c7cf33929 is described below

commit 56c7cf33929d7d42b7d299c0bb7e895963241214
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Sun Jan 8 16:49:15 2023 +0900

    [SPARK-41933][CONNECT] Provide local mode that automatically starts the server
    
    ### What changes were proposed in this pull request?
    
    This PR proposes a local mode for Spark Connect. When a `local*` master string is passed
    as the remote, it automatically starts a regular Spark session that launches the Spark
    Connect server, which introduces the two user-facing changes below.
    
    Note that local mode exactly follows the regular PySpark session's stop behavior by
    terminating the server (whereas non-local mode does not close the server or other
    sessions). See also the newly added comments in `pyspark.sql.connect.SparkSession.stop`.
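    
    For illustration, a minimal sketch of the stop behavior described above (the
    `sc://localhost` address used after restarting is an assumption for the example;
    any reachable Spark Connect server address works):
    
    ```python
    from pyspark.sql import SparkSession

    # Local mode: creating the session also launches the Spark Connect server.
    spark = SparkSession.builder.remote("local[4]").getOrCreate()
    spark.range(5).show()

    # In local mode, stop() also terminates the Spark Connect server,
    # mirroring the regular PySpark session's stop behavior.
    spark.stop()

    # The client can then be restarted against a different remote address
    # (assuming a Spark Connect server is reachable at sc://localhost).
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    ```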
    
    #### Local build of Apache Spark (for developers)
    
    In this case, the jars for Spark Connect are found automatically (they are not bundled
    in the regular Apache Spark release).
    
    - PySpark shell
        ```bash
        pyspark --remote local
        ```
    
    - PySpark application submission
        ```bash
        spark-submit --remote "local[4]" app.py
        ```
    
    - Use it as a Python library
        ```python
        from pyspark.sql import SparkSession
        SparkSession.builder.remote("local[*]").getOrCreate()
        ```
    
    #### Official release of Apache Spark (for end-users)
    
    Users must specify the jars or packages explicitly; they are not searched for automatically.
    
    - PySpark shell
        ```bash
        pyspark --packages org.apache.spark:spark-connect_2.12:3.4.0 --remote local
        ```
    
    - PySpark application submission
        ```bash
        spark-submit --packages org.apache.spark:spark-connect_2.12:3.4.0 --remote "local[4]" app.py
        ```
    
    - Use it as a Python library
        ```python
        from pyspark.sql import SparkSession
        SparkSession.builder.config(
            "spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.4.0"
        ).remote("local[*]").getOrCreate()
        ```
    
    ### Why are the changes needed?
    
    To provide an easier way to try Spark Connect for both developers and end users.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No for end users, because Spark Connect has not been released yet.
    For developers, yes. See the examples above.
    
    ### How was this patch tested?
    
    Unit tests were refactored to use and test this feature (which also deduplicated the
    test code).
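    
    For reference, the refactored parity tests now follow roughly the shape below (a
    condensed sketch of the pattern visible in the diff; the class and module names are
    taken from the patch itself):
    
    ```python
    from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin
    from pyspark.testing.connectutils import ReusedConnectTestCase


    class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
        # ReusedConnectTestCase starts a Spark Connect session in local mode via
        # SparkSession.builder.remote("local[4]"), so the tests no longer need to
        # manage SPARK_REMOTE or a separate server-side Spark session themselves.
        pass
    ```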
    
    Closes #39441 from HyukjinKwon/SPARK-41933.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/development/testing.rst         |  20 +--
 python/pyspark/context.py                          |   2 +-
 python/pyspark/sql/connect/catalog.py              |  15 +-
 python/pyspark/sql/connect/column.py               |  14 +-
 python/pyspark/sql/connect/dataframe.py            |  18 +-
 python/pyspark/sql/connect/functions.py            |  17 +-
 python/pyspark/sql/connect/group.py                |  15 +-
 python/pyspark/sql/connect/readwriter.py           |  16 +-
 python/pyspark/sql/connect/session.py              | 195 ++++++++++++++++++---
 python/pyspark/sql/connect/window.py               |  15 +-
 python/pyspark/sql/session.py                      |  59 ++++---
 .../sql/tests/connect/test_parity_catalog.py       |  23 +--
 .../sql/tests/connect/test_parity_dataframe.py     |  25 +--
 .../sql/tests/connect/test_parity_functions.py     |  28 +--
 python/pyspark/sql/tests/test_dataframe.py         |   2 +-
 python/pyspark/testing/connectutils.py             |  25 +++
 .../spark/sql/api/python/PythonSQLUtils.scala      |  14 ++
 17 files changed, 302 insertions(+), 201 deletions(-)

diff --git a/python/docs/source/development/testing.rst 
b/python/docs/source/development/testing.rst
index 0262c318cd6..e8255cab8c8 100644
--- a/python/docs/source/development/testing.rst
+++ b/python/docs/source/development/testing.rst
@@ -82,26 +82,14 @@ you should regenerate Python Protobuf client by running 
``dev/connect-gen-protos
 Running PySpark Shell with Python Client
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-To run Spark Connect server you locally built:
+For Apache Spark that you locally built:
 
 .. code-block:: bash
 
-    bin/spark-shell \
-      --jars `ls connector/connect/target/**/spark-connect*SNAPSHOT.jar | 
paste -sd ',' -` \
-      --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
+    bin/pyspark --remote "local[*]"
 
-To run the Spark Connect server from the Apache Spark release:
+For the Apache Spark release:
 
 .. code-block:: bash
 
-    bin/spark-shell \
-      --packages org.apache.spark:spark-connect_2.12:3.4.0 \
-      --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
-
-
-To run the PySpark Shell with the client for the Spark Connect server:
-
-.. code-block:: bash
-
-    bin/pyspark --remote sc://localhost
-
+    bin/pyspark --remote "local[*]" --packages org.apache.spark:spark-connect_2.12:3.4.0
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index f6e74493105..c3f18127121 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -179,7 +179,7 @@ class SparkContext:
         udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler,
         memory_profiler_cls: Type[MemoryProfiler] = MemoryProfiler,
     ):
-        if "SPARK_REMOTE" in os.environ and "SPARK_TESTING" not in os.environ:
+        if "SPARK_REMOTE" in os.environ and "SPARK_LOCAL_REMOTE" not in os.environ:
             raise RuntimeError(
                 "Remote client cannot create a SparkContext. Create 
SparkSession instead."
             )
diff --git a/python/pyspark/sql/connect/catalog.py 
b/python/pyspark/sql/connect/catalog.py
index 4c310ca57cd..790a1b8c000 100644
--- a/python/pyspark/sql/connect/catalog.py
+++ b/python/pyspark/sql/connect/catalog.py
@@ -18,7 +18,6 @@ from typing import List, Optional, TYPE_CHECKING
 
 import pandas as pd
 
-from pyspark import SparkContext, SparkConf
 from pyspark.sql.types import StructType
 from pyspark.sql.connect import DataFrame
 from pyspark.sql.catalog import (
@@ -324,16 +323,12 @@ def _test() -> None:
         import pyspark.sql.connect.catalog
 
         globs = pyspark.sql.connect.catalog.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.catalog tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(
-            sc, options={"spark.app.name": "sql.connect.catalog tests"}
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.catalog tests")
+            .remote("local[4]")
+            .getOrCreate()
         )
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
-
         # TODO(SPARK-41612): Support Catalog.isCached
         # TODO(SPARK-41600): Support Catalog.cacheTable
         del pyspark.sql.connect.catalog.Catalog.clearCache.__doc__
@@ -348,8 +343,8 @@ def _test() -> None:
             | doctest.NORMALIZE_WHITESPACE
             | doctest.IGNORE_EXCEPTION_DETAIL,
         )
-        globs["_spark"].stop()
         globs["spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/column.py 
b/python/pyspark/sql/connect/column.py
index 931cd453740..d26283571cc 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -28,7 +28,6 @@ from typing import (
     Optional,
 )
 
-from pyspark import SparkContext, SparkConf
 from pyspark.sql.types import DataType
 from pyspark.sql.column import Column as PySparkColumn
 
@@ -434,13 +433,12 @@ def _test() -> None:
         import pyspark.sql.connect.column
 
         globs = pyspark.sql.connect.column.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.column tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(sc, options={"spark.app.name": 
"sql.connect.column tests"})
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.column tests")
+            .remote("local[4]")
+            .getOrCreate()
+        )
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
         # Spark Connect has a different string representation for Column.
         del pyspark.sql.connect.column.Column.getItem.__doc__
 
@@ -456,7 +454,7 @@ def _test() -> None:
         )
 
         globs["spark"].stop()
-        globs["_spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index cf7b1596537..25f0cacb643 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -38,7 +38,7 @@ import json
 import warnings
 from collections.abc import Iterable
 
-from pyspark import _NoValue, SparkContext, SparkConf
+from pyspark import _NoValue
 from pyspark._globals import _NoValueType
 from pyspark.sql.types import StructType, Row
 
@@ -1532,12 +1532,6 @@ def _test() -> None:
         import pyspark.sql.connect.dataframe
 
         globs = pyspark.sql.connect.dataframe.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.dataframe tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(
-            sc, options={"spark.app.name": "sql.connect.dataframe tests"}
-        )
-
         # Spark Connect does not support RDD but the tests depend on them.
         del pyspark.sql.connect.dataframe.DataFrame.coalesce.__doc__
         del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__
@@ -1564,9 +1558,11 @@ def _test() -> None:
         # TODO(SPARK-41818): Support saveAsTable
         del pyspark.sql.connect.dataframe.DataFrame.write.__doc__
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.dataframe tests")
+            .remote("local[4]")
+            .getOrCreate()
+        )
 
         (failure_count, test_count) = doctest.testmod(
             pyspark.sql.connect.dataframe,
@@ -1577,7 +1573,7 @@ def _test() -> None:
         )
 
         globs["spark"].stop()
-        globs["_spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/functions.py 
b/python/pyspark/sql/connect/functions.py
index 356855581bd..993fd30e7f0 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -31,7 +31,6 @@ from typing import (
     cast,
 )
 
-from pyspark import SparkContext, SparkConf
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     CaseWhen,
@@ -2354,11 +2353,7 @@ def _test() -> None:
         import pyspark.sql.connect.functions
 
         globs = pyspark.sql.connect.functions.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.functions tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(
-            sc, options={"spark.app.name": "sql.connect.functions tests"}
-        )
+
         # Spark Connect does not support Spark Context but the test depends on 
that.
         del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
 
@@ -2407,9 +2402,11 @@ def _test() -> None:
         del pyspark.sql.connect.functions.map_zip_with.__doc__
         del pyspark.sql.connect.functions.posexplode.__doc__
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.functions tests")
+            .remote("local[4]")
+            .getOrCreate()
+        )
 
         (failure_count, test_count) = doctest.testmod(
             pyspark.sql.connect.functions,
@@ -2420,7 +2417,7 @@ def _test() -> None:
         )
 
         globs["spark"].stop()
-        globs["_spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/group.py 
b/python/pyspark/sql/connect/group.py
index df73f5b6fa2..3aa070ff8b6 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -226,7 +226,6 @@ def _test() -> None:
     import os
     import sys
     import doctest
-    from pyspark import SparkContext, SparkConf
     from pyspark.sql import SparkSession as PySparkSession
     from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -236,21 +235,21 @@ def _test() -> None:
         import pyspark.sql.connect.group
 
         globs = pyspark.sql.connect.group.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.group tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(sc, options={"spark.app.name": 
"sql.connect.group tests"})
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.group tests")
+            .remote("local[4]")
+            .getOrCreate()
+        )
 
         (failure_count, test_count) = doctest.testmod(
             pyspark.sql.connect.group,
             globs=globs,
             optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | 
doctest.REPORT_NDIFF,
         )
-        globs["_spark"].stop()
+
         globs["spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/readwriter.py 
b/python/pyspark/sql/connect/readwriter.py
index f3eab1a888e..80cecbc7f50 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -20,7 +20,6 @@ from typing import Dict
 from typing import Optional, Union, List, overload, Tuple, cast, Any
 from typing import TYPE_CHECKING
 
-from pyspark import SparkContext, SparkConf
 from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, 
WriteOperation
 from pyspark.sql.types import StructType
 from pyspark.sql.utils import to_str
@@ -497,11 +496,6 @@ def _test() -> None:
         import pyspark.sql.connect.readwriter
 
         globs = pyspark.sql.connect.readwriter.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.readwriter tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(
-            sc, options={"spark.app.name": "sql.connect.readwriter tests"}
-        )
 
         # TODO(SPARK-41817): Support reading with schema
         del pyspark.sql.connect.readwriter.DataFrameReader.load.__doc__
@@ -517,9 +511,11 @@ def _test() -> None:
         del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__
         del pyspark.sql.connect.readwriter.DataFrameWriter.saveAsTable.__doc__
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.readwriter tests")
+            .remote("local[4]")
+            .getOrCreate()
+        )
 
         (failure_count, test_count) = doctest.testmod(
             pyspark.sql.connect.readwriter,
@@ -530,7 +526,7 @@ def _test() -> None:
         )
 
         globs["spark"].stop()
-        globs["_spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index bd6e8bd19f3..0314f41bff6 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import warnings
+from distutils.version import LooseVersion
 from threading import RLock
 from collections.abc import Sized
 from functools import reduce
@@ -22,7 +25,8 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 
-from pyspark import SparkContext, SparkConf
+from pyspark import SparkContext, SparkConf, __version__
+from pyspark.java_gateway import launch_gateway
 from pyspark.sql.session import classproperty, SparkSession as PySparkSession
 from pyspark.sql.types import (
     _infer_schema,
@@ -53,6 +57,7 @@ from typing import (
     TYPE_CHECKING,
 )
 
+
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
@@ -344,25 +349,37 @@ class SparkSession:
 
     catalog.__doc__ = PySparkSession.catalog.__doc__
 
+    def __del__(self) -> None:
+        try:
+            # Try its best to close.
+            self.client.close()
+        except Exception:
+            pass
+
     def stop(self) -> None:
+        # Stopping the session only closes the connection to the current session (the
+        # life cycle of the session itself is maintained by the server), whereas stopping
+        # the regular PySpark session immediately terminates the Spark Context, which
+        # stops all Spark sessions.
+        # It is debatable whether to follow the regular Spark session's behavior here,
+        # because the Spark Connect server is designed for multi-tenancy - the remote
+        # client side should not be able to stop the server and thereby break other
+        # users' remote clients.
         self.client.close()
 
-    stop.__doc__ = PySparkSession.stop.__doc__
-
-    # SparkConnect-specific API
-    @property
-    def client(self) -> "SparkConnectClient":
-        """
-        Gives access to the Spark Connect client. In normal cases this is not 
necessary to be used
-        and only relevant for testing.
-        Returns
-        -------
-        :class:`SparkConnectClient`
-        """
-        return self._client
+        if "SPARK_LOCAL_REMOTE" in os.environ:
+            # When local mode is in use, follow the regular Spark session's
+            # behavior by terminating the Spark Connect server,
+            # meaning that you can stop local mode and restart the Spark Connect
+            # client with a different remote address.
+            active_session = PySparkSession.getActiveSession()
+            if active_session is not None:
+                active_session.stop()
+            with SparkContext._lock:
+                del os.environ["SPARK_LOCAL_REMOTE"]
+                del os.environ["SPARK_REMOTE"]
 
-    def register_udf(self, function: Any, return_type: Union[str, DataType]) 
-> str:
-        return self._client.register_udf(function, return_type)
+    stop.__doc__ = PySparkSession.stop.__doc__
 
     @classmethod
     def getActiveSession(cls) -> Any:
@@ -391,6 +408,138 @@ class SparkSession:
     def version(self) -> str:
         raise NotImplementedError("version() is not implemented.")
 
+    # SparkConnect-specific API
+    @property
+    def client(self) -> "SparkConnectClient":
+        """
+        Gives access to the Spark Connect client. In normal cases this is not necessary to be used
+        and only relevant for testing.
+        Returns
+        -------
+        :class:`SparkConnectClient`
+        """
+        return self._client
+
+    def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str:
+        return self._client.register_udf(function, return_type)
+
+    @staticmethod
+    def _start_connect_server(master: str) -> None:
+        """
+        Starts the Spark Connect server given the master.
+
+        At a high level, there are two cases. The first case is when you locally build
+        Apache Spark and run ``SparkSession.builder.remote("local")``:
+
+        1. This method automatically finds the jars for Spark Connect (because the jars for
+          Spark Connect are not bundled in the regular Apache Spark release).
+
+        2. It temporarily removes all state related to Spark Connect, for example, the
+          ``SPARK_REMOTE`` environment variable.
+
+            2.1. For PySpark application submission, e.g., ``bin/spark-submit app.py``, it
+              first starts a JVM (without a Spark Context), and adds the Spark Connect server
+              jars into the current class loader. Otherwise, a Spark Context with
+              ``spark.plugins`` could not be initialized because the JVM was already running
+              without the jars on the class path before this Python driver process started.
+
+        3. It starts a regular Spark session that automatically launches the Spark Connect
+          server in the JVM (if it is not up yet) via the ``spark.plugins`` feature.
+
+        The second case is when you use an Apache Spark release:
+
+        1. Users must specify the jars or packages, e.g., ``--packages
+          org.apache.spark:spark-connect_2.12:3.4.0``. The jars or packages are handled by
+          SparkSubmit automatically, so this method does nothing for them.
+
+        2. It temporarily removes all state related to Spark Connect, for example, the
+          ``SPARK_REMOTE`` environment variable. It also does nothing extra for PySpark
+          application submission because the jars or packages were already specified before
+          this Python driver process started.
+
+        3. It starts a regular Spark session that automatically launches the Spark Connect
+          server in the JVM via the ``spark.plugins`` feature.
+        """
+        session = PySparkSession._instantiatedSession
+        if session is None or session._sc._jsc is None:
+            conf = SparkConf()
+            # Check if we're using unreleased version that is in development.
+            # Also checks SPARK_TESTING for RC versions.
+            is_dev_mode = (
+                "dev" in LooseVersion(__version__).version or "SPARK_TESTING" in os.environ
+            )
+            connect_jar = None
+
+            if is_dev_mode:
+                from pyspark.testing.utils import search_jar
+
+                # Note that, in production, the spark.jars.packages configuration should
+                # be set by users. Here we automatically search for the locally built jars.
+                connect_jar = search_jar(
+                    "connector/connect/server", "spark-connect-assembly-", "spark-connect"
+                )
+                if connect_jar is not None:
+                    origin_jars = conf.get("spark.jars")
+                    if origin_jars is not None:
+                        conf.set("spark.jars", f"{origin_jars},{connect_jar}")
+                    else:
+                        conf.set("spark.jars", connect_jar)
+                else:
+                    warnings.warn(
+                        "Attempted to automatically find the Spark Connect jars because "
+                        "'SPARK_TESTING' environment variable is set, or the current PySpark "
+                        f"version is dev version ({__version__}). However, the jar was not found. "
+                        "Manually locate the jars and specify them, e.g., 'spark.jars' "
+                        "configuration."
+                    )
+
+            conf.set("spark.master", master)
+
+            connect_plugin = "org.apache.spark.sql.connect.SparkConnectPlugin"
+            origin_plugin = conf.get("spark.plugins")
+            if origin_plugin is not None:
+                conf.set("spark.plugins", f"{origin_plugin},{connect_plugin}")
+            else:
+                conf.set("spark.plugins", connect_plugin)
+
+            origin_remote = os.environ.get("SPARK_REMOTE", None)
+            origin_args = os.environ.get("PYSPARK_SUBMIT_ARGS", None)
+            try:
+                if origin_remote is not None:
+                    # So SparkSubmit thinks no remote is set in order to
+                    # start the regular PySpark session.
+                    del os.environ["SPARK_REMOTE"]
+
+                # PySpark shell launches Py4J server from Python.
+                # Remove the "--remote" option if specified, and use plain arguments.
+                # NOTE that this is not used in regular PySpark application
+                # submission because the JVM at this point is already running.
+                os.environ["PYSPARK_SUBMIT_ARGS"] = '"--name" "PySparkShell" "pyspark-shell"'
+
+                if is_dev_mode and connect_jar is not None:
+                    # In the case of Python application submission, the JVM is already up.
+                    # Therefore, we should manually manipulate the classpath in that case.
+                    # Otherwise, the jars are added but the driver would not be able to
+                    # find the server jars.
+                    with SparkContext._lock:
+                        if not SparkContext._gateway:
+                            SparkContext._gateway = launch_gateway(conf)
+                            SparkContext._jvm = SparkContext._gateway.jvm
+                            SparkContext._jvm.PythonSQLUtils.addJarToCurrentClassLoader(connect_jar)
+
+                # The regular PySpark session is registered as an active session
+                # so it will not be garbage-collected.
+                PySparkSession(SparkContext.getOrCreate(conf))
+            finally:
+                if origin_args is not None:
+                    os.environ["PYSPARK_SUBMIT_ARGS"] = origin_args
+                else:
+                    del os.environ["PYSPARK_SUBMIT_ARGS"]
+                if origin_remote is not None:
+                    os.environ["SPARK_REMOTE"] = origin_remote
+        else:
+            raise RuntimeError("There should not be an existing Spark Session or Spark Context.")
+
 
 SparkSession.__doc__ = PySparkSession.__doc__
 
@@ -408,16 +557,12 @@ def _test() -> None:
         import pyspark.sql.connect.session
 
         globs = pyspark.sql.connect.session.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.session tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(
-            sc, options={"spark.app.name": "sql.connect.session tests"}
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.session tests")
+            .remote("local[4]")
+            .getOrCreate()
         )
 
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
-
         # Uses PySpark session to test builder.
         globs["SparkSession"] = PySparkSession
         # Spark Connect does not support to set master together.
@@ -438,7 +583,7 @@ def _test() -> None:
         )
 
         globs["spark"].stop()
-        globs["_spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/connect/window.py 
b/python/pyspark/sql/connect/window.py
index 315d74709fb..d504b19f800 100644
--- a/python/pyspark/sql/connect/window.py
+++ b/python/pyspark/sql/connect/window.py
@@ -18,7 +18,6 @@
 import sys
 from typing import TYPE_CHECKING, Union, Sequence, List, Optional
 
-from pyspark import SparkContext, SparkConf
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     ColumnReference,
@@ -240,13 +239,11 @@ def _test() -> None:
         import pyspark.sql.connect.window
 
         globs = pyspark.sql.connect.window.__dict__.copy()
-        # Works around to create a regular Spark session
-        sc = SparkContext("local[4]", "sql.connect.window tests", 
conf=SparkConf())
-        globs["_spark"] = PySparkSession(sc, options={"spark.app.name": 
"sql.connect.window tests"})
-
-        # Creates a remote Spark session.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        globs["spark"] = 
PySparkSession.builder.remote("sc://localhost").getOrCreate()
+        globs["spark"] = (
+            PySparkSession.builder.appName("sql.connect.window tests")
+            .remote("local[4]")
+            .getOrCreate()
+        )
 
         (failure_count, test_count) = doctest.testmod(
             pyspark.sql.connect.window,
@@ -257,7 +254,7 @@ def _test() -> None:
         )
 
         globs["spark"].stop()
-        globs["_spark"].stop()
+
         if failure_count:
             sys.exit(-1)
     else:
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 1e4e6f5e3ad..2d6b3f88e1f 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -270,10 +270,13 @@ class SparkSession(SparkConversionMixin):
                             % self._options.get("spark.master", 
os.environ.get("MASTER"))
                         )
 
-                    if "SPARK_REMOTE" in os.environ and 
os.environ["SPARK_REMOTE"] != v:
+                    if ("SPARK_REMOTE" in os.environ and os.environ["SPARK_REMOTE"] != v) and (
+                        "SPARK_LOCAL_REMOTE" in os.environ and not v.startswith("local")
+                    ):
                         raise RuntimeError(
-                            "Only one Spark Connect client URL can be set; 
however, got a different"
-                            "URL [%s] from the existing [%s]" % 
(os.environ["SPARK_REMOTE"], v)
+                            "Only one Spark Connect client URL can be set; however, got a "
+                            "different URL [%s] from the existing [%s]"
+                            % (os.environ["SPARK_REMOTE"], v)
                         )
 
             with self._lock:
@@ -415,31 +418,39 @@ class SparkSession(SparkConversionMixin):
             True
             """
             from pyspark.context import SparkContext
+            from pyspark.conf import SparkConf
 
             opts = dict(self._options)
-            if "SPARK_REMOTE" in os.environ or "spark.remote" in opts:
-                with SparkContext._lock:
-                    from pyspark.sql.connect.session import SparkSession as 
RemoteSparkSession
-
-                    if (
-                        SparkContext._active_spark_context is None
-                        and SparkSession._instantiatedSession is None
-                    ):
-                        url = opts.get("spark.remote", 
os.environ.get("SPARK_REMOTE"))
-                        os.environ["SPARK_REMOTE"] = url
-                        opts["spark.remote"] = url
-                        return 
RemoteSparkSession.builder.config(map=opts).getOrCreate()
-                    elif "SPARK_TESTING" not in os.environ:
-                        raise RuntimeError(
-                            "Cannot start a remote Spark session because there 
"
-                            "is a regular Spark Connect already running."
-                        )
-
-                # Cannot reach here in production. Test-only.
-                return 
RemoteSparkSession.builder.config(map=opts).getOrCreate()
 
             with self._lock:
-                from pyspark.conf import SparkConf
+                if "SPARK_REMOTE" in os.environ or "spark.remote" in opts:
+                    with SparkContext._lock:
+                        from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+
+                        if (
+                            SparkContext._active_spark_context is None
+                            and SparkSession._instantiatedSession is None
+                        ):
+                            url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE"))
+
+                            if url.startswith("local"):
+                                os.environ["SPARK_LOCAL_REMOTE"] = "1"
+                                RemoteSparkSession._start_connect_server(url)
+                                url = "sc://localhost"
+
+                            os.environ["SPARK_REMOTE"] = url
+                            opts["spark.remote"] = url
+                            return RemoteSparkSession.builder.config(map=opts).getOrCreate()
+                        elif "SPARK_LOCAL_REMOTE" in os.environ:
+                            url = "sc://localhost"
+                            os.environ["SPARK_REMOTE"] = url
+                            opts["spark.remote"] = url
+                            return RemoteSparkSession.builder.config(map=opts).getOrCreate()
+                        else:
+                            raise RuntimeError(
+                                "Cannot start a remote Spark session because 
there "
+                                "is a regular Spark Connect already running."
+                            )
 
                 session = SparkSession._instantiatedSession
                 if session is None or session._sc._jsc is None:
diff --git a/python/pyspark/sql/tests/connect/test_parity_catalog.py 
b/python/pyspark/sql/tests/connect/test_parity_catalog.py
index e5135edb8cf..3da702198ac 100644
--- a/python/pyspark/sql/tests/connect/test_parity_catalog.py
+++ b/python/pyspark/sql/tests/connect/test_parity_catalog.py
@@ -16,31 +16,12 @@
 #
 
 import unittest
-import os
 
-from pyspark.sql import SparkSession
 from pyspark.sql.tests.test_catalog import CatalogTestsMixin
-from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class CatalogParityTests(CatalogTestsMixin, ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super(CatalogParityTests, cls).setUpClass()
-        cls._spark = cls.spark  # Assign existing Spark session to run the 
server
-        # Sets the remote address. Now, we create a remote Spark Session.
-        # Note that this is only allowed in testing.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        super(CatalogParityTests, cls).tearDownClass()
-        cls._spark.stop()
-        del os.environ["SPARK_REMOTE"]
-
+class CatalogParityTests(CatalogTestsMixin, ReusedConnectTestCase):
     # TODO(SPARK-41612): Support Catalog.isCached
     # TODO(SPARK-41600): Support Catalog.cacheTable
     # TODO(SPARK-41623): Support Catalog.uncacheTable
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py 
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index f5d6db38e8d..ed7c85d32ab 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -16,31 +16,12 @@
 #
 
 import unittest
-import os
 
-from pyspark.sql import SparkSession
 from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin
-from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-
-
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class DataFrameParityTests(DataFrameTestsMixin, ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super(DataFrameParityTests, cls).setUpClass()
-        cls._spark = cls.spark  # Assign existing Spark session to run the 
server
-        # Sets the remote address. Now, we create a remote Spark Session.
-        # Note that this is only allowed in testing.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        super(DataFrameParityTests, cls).tearDownClass()
-        cls._spark.stop()
-        del os.environ["SPARK_REMOTE"]
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
+
+class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     # TODO(SPARK-41612): support Catalog.isCached
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_cache(self):
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py 
b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 3e46e1caa3e..89571aff6d6 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -16,34 +16,12 @@
 #
 
 import unittest
-import os
 
-from pyspark.sql import SparkSession
 from pyspark.sql.tests.test_functions import FunctionsTestsMixin
-from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-
-
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class FunctionsParityTests(ReusedSQLTestCase, FunctionsTestsMixin):
-    @classmethod
-    def setUpClass(cls):
-        from pyspark.sql.connect.session import SparkSession as 
RemoteSparkSession
-
-        super(FunctionsParityTests, cls).setUpClass()
-        cls._spark = cls.spark  # Assign existing Spark session to run the 
server
-        # Sets the remote address. Now, we create a remote Spark Session.
-        # Note that this is only allowed in testing.
-        os.environ["SPARK_REMOTE"] = "sc://localhost"
-        cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
-        assert isinstance(cls.spark, RemoteSparkSession)
-
-    @classmethod
-    def tearDownClass(cls):
-        super(FunctionsParityTests, cls).tearDownClass()
-        cls.spark = cls._spark.stop()
-        del os.environ["SPARK_REMOTE"]
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
+
+class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     # TODO(SPARK-41897): Parity in Error types between pyspark and connect 
functions
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_assert_true(self):
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index bb87393b362..4d0d5541067 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -531,7 +531,7 @@ class DataFrameTestsMixin:
 
         # Type check
         self.assertRaises(TypeError, self.df.withColumns, ["key"])
-        self.assertRaises(AssertionError, self.df.withColumns)
+        self.assertRaises(Exception, self.df.withColumns)
 
     def test_generic_hints(self):
         from pyspark.sql import DataFrame
diff --git a/python/pyspark/testing/connectutils.py 
b/python/pyspark/testing/connectutils.py
index bec116b5f79..0ebce47bd54 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -14,17 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import shutil
+import tempfile
 import typing
 import os
 import functools
 import unittest
 
+from pyspark import Row
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
+    SQLTestUtils,
 )
+from pyspark.sql.session import SparkSession as PySparkSession
 
 
 grpc_requirement_message = None
@@ -154,3 +159,23 @@ class PlanOnlyTestFixture(unittest.TestCase):
         cls.connect.drop_hook("range")
         cls.connect.drop_hook("sql")
         cls.connect.drop_hook("with_plan")
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils):
+    """
+    Spark Connect version of :class:`pyspark.testing.sqlutils.ReusedSQLTestCase`.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = PySparkSession.builder.appName(cls.__name__).remote("local[4]").getOrCreate()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        cls.df = cls.spark.createDataFrame(cls.testData)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+        cls.spark.stop()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 70474f4d5c4..196377cce2a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.util.{MutableURLClassLoader, Utils}
 
 private[sql] object PythonSQLUtils extends Logging {
   private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = 
{
@@ -126,6 +127,19 @@ private[sql] object PythonSQLUtils extends Logging {
     deserializer(internalRow)
   }
 
+  /**
+   * Internal-only helper for Spark Connect's local mode. This is only used for
+   * local development, not for production. This method should not be used in
+   * production code path.
+   */
+  def addJarToCurrentClassLoader(path: String): Unit = {
+    Utils.getContextOrSparkClassLoader match {
+      case cl: MutableURLClassLoader => cl.addURL(Utils.resolveURI(path).toURL)
+      case cl => logWarning(
+        s"Unsupported class loader $cl will not update jars in the thread 
class loader.")
+    }
+  }
+
   def castTimestampNTZToLong(c: Column): Column = 
Column(CastTimestampNTZToLong(c.expr))
 
   def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

