This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 56c7cf33929 [SPARK-41933][CONNECT] Provide local mode that automatically starts the server
56c7cf33929 is described below

commit 56c7cf33929d7d42b7d299c0bb7e895963241214
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Sun Jan 8 16:49:15 2023 +0900

[SPARK-41933][CONNECT] Provide local mode that automatically starts the server

### What changes were proposed in this pull request?

This PR proposes a local mode for Spark Connect. It automatically starts a Spark session (by passing a `local*` master string) that launches the Spark Connect server, which introduces the two user-facing changes below. Note that local mode follows the regular PySpark session's stop behavior exactly: stopping the session terminates the server (whereas non-local mode does not close the server or other sessions). See also the newly added comments for `pyspark.sql.connect.SparkSession.stop`.

#### Local build of Apache Spark (for developers)

The jars for Spark Connect are found automatically (because they are not bundled in the regular Apache Spark release).

- PySpark shell

  ```bash
  pyspark --remote local
  ```

- PySpark application submission

  ```bash
  spark-submit --remote "local[4]" app.py
  ```

- Use it as a Python library

  ```python
  from pyspark.sql import SparkSession
  SparkSession.builder.remote("local[*]").getOrCreate()
  ```

#### Official release of Apache Spark (for end-users)

Users must specify jars or packages; they are not searched for automatically.

- PySpark shell

  ```bash
  pyspark --packages org.apache.spark:spark-connect_2.12:3.4.0 --remote local
  ```

- PySpark application submission

  ```bash
  spark-submit --packages org.apache.spark:spark-connect_2.12:3.4.0 --remote "local[4]" app.py
  ```

- Use it as a Python library

  ```python
  from pyspark.sql import SparkSession
  SparkSession.builder.config(
      "spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.4.0"
  ).remote("local[*]").getOrCreate()
  ```

### Why are the changes needed?

To provide an easier way to try Spark Connect for both developers and end-users.

### Does this PR introduce _any_ user-facing change?

No to end users, because Spark Connect has not been released yet. To developers, yes; see the examples above.

### How was this patch tested?

Unit tests were refactored to use and test this feature (which also deduplicated code).

Closes #39441 from HyukjinKwon/SPARK-41933.
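As a concrete illustration of the local mode described above, here is a minimal sketch, assuming a local dev build of Spark (or the `spark.jars.packages` configuration shown earlier) and that the usual DataFrame operations are available over Spark Connect; the sample rows are illustrative only.

```python
from pyspark.sql import Row, SparkSession

# "local[*]" asks the builder to start a local Spark Connect server
# (via spark.plugins) and return a Connect client session bound to it.
spark = SparkSession.builder.remote("local[*]").getOrCreate()

df = spark.createDataFrame([Row(key=i, value=str(i)) for i in range(3)])
df.show()

# In local mode, stop() follows the regular PySpark session's behavior:
# it closes the client connection and terminates the locally started
# Spark Connect server, so a later builder call can target a different
# remote address. In non-local mode, stop() only closes this client.
spark.stop()
```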
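The refactored unit tests mentioned under "How was this patch tested?" build on the new `pyspark.testing.connectutils.ReusedConnectTestCase` base class added by this patch, which replaces the per-suite `SPARK_REMOTE` setup. A minimal sketch of such a test follows; the class name and assertion are illustrative assumptions, not code from the patch.

```python
import unittest

from pyspark.testing.connectutils import ReusedConnectTestCase


class ExampleConnectTests(ReusedConnectTestCase):
    # The base class starts a local-mode Spark Connect session in setUpClass
    # (and stops it in tearDownClass), exposing it as cls.spark together with
    # a 100-row sample DataFrame as cls.df.
    def test_sample_dataframe_size(self):
        self.assertEqual(len(self.df.collect()), 100)


if __name__ == "__main__":
    unittest.main()
```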
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/docs/source/development/testing.rst | 20 +-- python/pyspark/context.py | 2 +- python/pyspark/sql/connect/catalog.py | 15 +- python/pyspark/sql/connect/column.py | 14 +- python/pyspark/sql/connect/dataframe.py | 18 +- python/pyspark/sql/connect/functions.py | 17 +- python/pyspark/sql/connect/group.py | 15 +- python/pyspark/sql/connect/readwriter.py | 16 +- python/pyspark/sql/connect/session.py | 195 ++++++++++++++++++--- python/pyspark/sql/connect/window.py | 15 +- python/pyspark/sql/session.py | 59 ++++--- .../sql/tests/connect/test_parity_catalog.py | 23 +-- .../sql/tests/connect/test_parity_dataframe.py | 25 +-- .../sql/tests/connect/test_parity_functions.py | 28 +-- python/pyspark/sql/tests/test_dataframe.py | 2 +- python/pyspark/testing/connectutils.py | 25 +++ .../spark/sql/api/python/PythonSQLUtils.scala | 14 ++ 17 files changed, 302 insertions(+), 201 deletions(-) diff --git a/python/docs/source/development/testing.rst b/python/docs/source/development/testing.rst index 0262c318cd6..e8255cab8c8 100644 --- a/python/docs/source/development/testing.rst +++ b/python/docs/source/development/testing.rst @@ -82,26 +82,14 @@ you should regenerate Python Protobuf client by running ``dev/connect-gen-protos Running PySpark Shell with Python Client ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To run Spark Connect server you locally built: +For Apache Spark you locally built: .. code-block:: bash - bin/spark-shell \ - --jars `ls connector/connect/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \ - --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin + bin/pyspark --remote "local[*]" -To run the Spark Connect server from the Apache Spark release: +For the Apache Spark release: .. code-block:: bash - bin/spark-shell \ - --packages org.apache.spark:spark-connect_2.12:3.4.0 \ - --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin - - -To run the PySpark Shell with the client for the Spark Connect server: - -.. code-block:: bash - - bin/pyspark --remote sc://localhost - + bin/pyspark --remote "local[*]" --packages org.apache.spark:spark-connect_2.12:3.4.0 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index f6e74493105..c3f18127121 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -179,7 +179,7 @@ class SparkContext: udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler, memory_profiler_cls: Type[MemoryProfiler] = MemoryProfiler, ): - if "SPARK_REMOTE" in os.environ and "SPARK_TESTING" not in os.environ: + if "SPARK_REMOTE" in os.environ and "SPARK_LOCAL_REMOTE" not in os.environ: raise RuntimeError( "Remote client cannot create a SparkContext. Create SparkSession instead." 
) diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py index 4c310ca57cd..790a1b8c000 100644 --- a/python/pyspark/sql/connect/catalog.py +++ b/python/pyspark/sql/connect/catalog.py @@ -18,7 +18,6 @@ from typing import List, Optional, TYPE_CHECKING import pandas as pd -from pyspark import SparkContext, SparkConf from pyspark.sql.types import StructType from pyspark.sql.connect import DataFrame from pyspark.sql.catalog import ( @@ -324,16 +323,12 @@ def _test() -> None: import pyspark.sql.connect.catalog globs = pyspark.sql.connect.catalog.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.catalog tests", conf=SparkConf()) - globs["_spark"] = PySparkSession( - sc, options={"spark.app.name": "sql.connect.catalog tests"} + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.catalog tests") + .remote("local[4]") + .getOrCreate() ) - # Creates a remote Spark session. - os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() - # TODO(SPARK-41612): Support Catalog.isCached # TODO(SPARK-41600): Support Catalog.cacheTable del pyspark.sql.connect.catalog.Catalog.clearCache.__doc__ @@ -348,8 +343,8 @@ def _test() -> None: | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL, ) - globs["_spark"].stop() globs["spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 931cd453740..d26283571cc 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -28,7 +28,6 @@ from typing import ( Optional, ) -from pyspark import SparkContext, SparkConf from pyspark.sql.types import DataType from pyspark.sql.column import Column as PySparkColumn @@ -434,13 +433,12 @@ def _test() -> None: import pyspark.sql.connect.column globs = pyspark.sql.connect.column.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.column tests", conf=SparkConf()) - globs["_spark"] = PySparkSession(sc, options={"spark.app.name": "sql.connect.column tests"}) + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.column tests") + .remote("local[4]") + .getOrCreate() + ) - # Creates a remote Spark session. - os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() # Spark Connect has a different string representation for Column. 
del pyspark.sql.connect.column.Column.getItem.__doc__ @@ -456,7 +454,7 @@ def _test() -> None: ) globs["spark"].stop() - globs["_spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index cf7b1596537..25f0cacb643 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -38,7 +38,7 @@ import json import warnings from collections.abc import Iterable -from pyspark import _NoValue, SparkContext, SparkConf +from pyspark import _NoValue from pyspark._globals import _NoValueType from pyspark.sql.types import StructType, Row @@ -1532,12 +1532,6 @@ def _test() -> None: import pyspark.sql.connect.dataframe globs = pyspark.sql.connect.dataframe.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.dataframe tests", conf=SparkConf()) - globs["_spark"] = PySparkSession( - sc, options={"spark.app.name": "sql.connect.dataframe tests"} - ) - # Spark Connect does not support RDD but the tests depend on them. del pyspark.sql.connect.dataframe.DataFrame.coalesce.__doc__ del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__ @@ -1564,9 +1558,11 @@ def _test() -> None: # TODO(SPARK-41818): Support saveAsTable del pyspark.sql.connect.dataframe.DataFrame.write.__doc__ - # Creates a remote Spark session. - os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.dataframe tests") + .remote("local[4]") + .getOrCreate() + ) (failure_count, test_count) = doctest.testmod( pyspark.sql.connect.dataframe, @@ -1577,7 +1573,7 @@ def _test() -> None: ) globs["spark"].stop() - globs["_spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 356855581bd..993fd30e7f0 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -31,7 +31,6 @@ from typing import ( cast, ) -from pyspark import SparkContext, SparkConf from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( CaseWhen, @@ -2354,11 +2353,7 @@ def _test() -> None: import pyspark.sql.connect.functions globs = pyspark.sql.connect.functions.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.functions tests", conf=SparkConf()) - globs["_spark"] = PySparkSession( - sc, options={"spark.app.name": "sql.connect.functions tests"} - ) + # Spark Connect does not support Spark Context but the test depends on that. del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__ @@ -2407,9 +2402,11 @@ def _test() -> None: del pyspark.sql.connect.functions.map_zip_with.__doc__ del pyspark.sql.connect.functions.posexplode.__doc__ - # Creates a remote Spark session. 
- os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.functions tests") + .remote("local[4]") + .getOrCreate() + ) (failure_count, test_count) = doctest.testmod( pyspark.sql.connect.functions, @@ -2420,7 +2417,7 @@ def _test() -> None: ) globs["spark"].stop() - globs["_spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index df73f5b6fa2..3aa070ff8b6 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -226,7 +226,6 @@ def _test() -> None: import os import sys import doctest - from pyspark import SparkContext, SparkConf from pyspark.sql import SparkSession as PySparkSession from pyspark.testing.connectutils import should_test_connect, connect_requirement_message @@ -236,21 +235,21 @@ def _test() -> None: import pyspark.sql.connect.group globs = pyspark.sql.connect.group.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.group tests", conf=SparkConf()) - globs["_spark"] = PySparkSession(sc, options={"spark.app.name": "sql.connect.group tests"}) - # Creates a remote Spark session. - os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.group tests") + .remote("local[4]") + .getOrCreate() + ) (failure_count, test_count) = doctest.testmod( pyspark.sql.connect.group, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF, ) - globs["_spark"].stop() + globs["spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index f3eab1a888e..80cecbc7f50 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -20,7 +20,6 @@ from typing import Dict from typing import Optional, Union, List, overload, Tuple, cast, Any from typing import TYPE_CHECKING -from pyspark import SparkContext, SparkConf from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation from pyspark.sql.types import StructType from pyspark.sql.utils import to_str @@ -497,11 +496,6 @@ def _test() -> None: import pyspark.sql.connect.readwriter globs = pyspark.sql.connect.readwriter.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.readwriter tests", conf=SparkConf()) - globs["_spark"] = PySparkSession( - sc, options={"spark.app.name": "sql.connect.readwriter tests"} - ) # TODO(SPARK-41817): Support reading with schema del pyspark.sql.connect.readwriter.DataFrameReader.load.__doc__ @@ -517,9 +511,11 @@ def _test() -> None: del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__ del pyspark.sql.connect.readwriter.DataFrameWriter.saveAsTable.__doc__ - # Creates a remote Spark session. 
- os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.readwriter tests") + .remote("local[4]") + .getOrCreate() + ) (failure_count, test_count) = doctest.testmod( pyspark.sql.connect.readwriter, @@ -530,7 +526,7 @@ def _test() -> None: ) globs["spark"].stop() - globs["_spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index bd6e8bd19f3..0314f41bff6 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import warnings +from distutils.version import LooseVersion from threading import RLock from collections.abc import Sized from functools import reduce @@ -22,7 +25,8 @@ import numpy as np import pandas as pd import pyarrow as pa -from pyspark import SparkContext, SparkConf +from pyspark import SparkContext, SparkConf, __version__ +from pyspark.java_gateway import launch_gateway from pyspark.sql.session import classproperty, SparkSession as PySparkSession from pyspark.sql.types import ( _infer_schema, @@ -53,6 +57,7 @@ from typing import ( TYPE_CHECKING, ) + if TYPE_CHECKING: from pyspark.sql.connect._typing import OptionalPrimitiveType from pyspark.sql.connect.catalog import Catalog @@ -344,25 +349,37 @@ class SparkSession: catalog.__doc__ = PySparkSession.catalog.__doc__ + def __del__(self) -> None: + try: + # Try its best to close. + self.client.close() + except Exception: + pass + def stop(self) -> None: + # Stopping the session will only close the connection to the current session (and + # the life cycle of the session is maintained by the server), + # whereas the regular PySpark session immediately terminates the Spark Context + # itself, meaning that stopping all Spark sessions. + # It is controversial to follow the existing the regular Spark session's behavior + # specifically in Spark Connect the Spark Connect server is designed for + # multi-tenancy - the remote client side cannot just stop the server and stop + # other remote clients being used from other users. self.client.close() - stop.__doc__ = PySparkSession.stop.__doc__ - - # SparkConnect-specific API - @property - def client(self) -> "SparkConnectClient": - """ - Gives access to the Spark Connect client. In normal cases this is not necessary to be used - and only relevant for testing. - Returns - ------- - :class:`SparkConnectClient` - """ - return self._client + if "SPARK_LOCAL_REMOTE" in os.environ: + # When local mode is in use, follow the regular Spark session's + # behavior by terminating the Spark Connect server, + # meaning that you can stop local mode, and restart the Spark Connect + # client with a different remote address. 
+ active_session = PySparkSession.getActiveSession() + if active_session is not None: + active_session.stop() + with SparkContext._lock: + del os.environ["SPARK_LOCAL_REMOTE"] + del os.environ["SPARK_REMOTE"] - def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str: - return self._client.register_udf(function, return_type) + stop.__doc__ = PySparkSession.stop.__doc__ @classmethod def getActiveSession(cls) -> Any: @@ -391,6 +408,138 @@ class SparkSession: def version(self) -> str: raise NotImplementedError("version() is not implemented.") + # SparkConnect-specific API + @property + def client(self) -> "SparkConnectClient": + """ + Gives access to the Spark Connect client. In normal cases this is not necessary to be used + and only relevant for testing. + Returns + ------- + :class:`SparkConnectClient` + """ + return self._client + + def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str: + return self._client.register_udf(function, return_type) + + @staticmethod + def _start_connect_server(master: str) -> None: + """ + Starts the Spark Connect server given the master. + + At the high level, there are two cases. The first case is when you locally build + Apache Spark, and run ``SparkSession.builder.remote("local")``: + + 1. This method automatically finds the jars for Spark Connect (because the jars for + Spark Connect are not bundled in the regular Apache Spark release). + + 2. Temporarily remove all states for Spark Connect, for example, ``SPARK_REMOTE`` + environment variable. + + 2.1. If we're in PySpark application submission, e.g., ``bin/spark-submit app.py`` + starts a JVM (without Spark Context) first, and adds the Spark Connect server jars + into the current class loader. Otherwise, Spark Context with ``spark.plugins`` + cannot be initialized because the JVM is already running without the jars in + the class path before executing this Python process for driver side. + + 3. Starts a regular Spark session that automatically starts a Spark Connect server + with JVM (if it is not up) via ``spark.plugins`` feature. + + The second case is when you use Apache Spark release: + + 1. Users must specify either the jars or package, e.g., ``--packages + org.apache.spark:spark-connect_2.12:3.4.0``. The jars or packages would be specified + in SparkSubmit automatically. This method does not do anything related to this. + + 2. Temporarily remove all states for Spark Connect, for example, ``SPARK_REMOTE`` + environment variable. It does not do anything for PySpark application submission as + well because jars or packages were already specified before executing this Python + process for driver side. + + 3. Starts a regular Spark session that automatically starts a Spark Connect server + with JVM via ``spark.plugins`` feature. + """ + session = PySparkSession._instantiatedSession + if session is None or session._sc._jsc is None: + conf = SparkConf() + # Check if we're using unreleased version that is in development. + # Also checks SPARK_TESTING for RC versions. + is_dev_mode = ( + "dev" in LooseVersion(__version__).version or "SPARK_TESTING" in os.environ + ) + connect_jar = None + + if is_dev_mode: + from pyspark.testing.utils import search_jar + + # Note that, in production, spark.jars.packages configuration should be + # set by users. Here we're automatically searching the jars locally built. 
+ connect_jar = search_jar( + "connector/connect/server", "spark-connect-assembly-", "spark-connect" + ) + if connect_jar is not None: + origin_jars = conf.get("spark.jars") + if origin_jars is not None: + conf.set("spark.jars", f"{origin_jars},{connect_jar}") + else: + conf.set("spark.jars", connect_jar) + else: + warnings.warn( + "Attempted to automatically find the Spark Connect jars because " + "'SPARK_TESTING' environment variable is set, or the current PySpark " + f"version is dev version ({__version__}). However, the jar was not found. " + "Manually locate the jars and specify them, e.g., 'spark.jars' " + "configuration." + ) + + conf.set("spark.master", master) + + connect_plugin = "org.apache.spark.sql.connect.SparkConnectPlugin" + origin_plugin = conf.get("spark.plugins") + if origin_plugin is not None: + conf.set("spark.plugins", f"{origin_plugin},{connect_plugin}") + else: + conf.set("spark.plugins", connect_plugin) + + origin_remote = os.environ.get("SPARK_REMOTE", None) + origin_args = os.environ.get("PYSPARK_SUBMIT_ARGS", None) + try: + if origin_remote is not None: + # So SparkSubmit thinks no remote is set in order to + # start the regular PySpark session. + del os.environ["SPARK_REMOTE"] + + # PySpark shell launches Py4J server from Python. + # Remove "--remote" option specified, and use plain arguments. + # NOTE that this is not used in regular PySpark application + # submission because JVM at this point is already running. + os.environ["PYSPARK_SUBMIT_ARGS"] = '"--name" "PySparkShell" "pyspark-shell"' + + if is_dev_mode and connect_jar is not None: + # In the case of Python application submission, JVM is already up. + # Therefore, we should manually manipulate the classpath in that case. + # Otherwise, the jars are added but the driver would not be able to + # find the server jars. + with SparkContext._lock: + if not SparkContext._gateway: + SparkContext._gateway = launch_gateway(conf) + SparkContext._jvm = SparkContext._gateway.jvm + SparkContext._jvm.PythonSQLUtils.addJarToCurrentClassLoader(connect_jar) + + # The regular PySpark session is registered as an active session + # so would not be garbage-collected. + PySparkSession(SparkContext.getOrCreate(conf)) + finally: + if origin_args is not None: + os.environ["PYSPARK_SUBMIT_ARGS"] = origin_args + else: + del os.environ["PYSPARK_SUBMIT_ARGS"] + if origin_remote is not None: + os.environ["SPARK_REMOTE"] = origin_remote + else: + raise RuntimeError("There should not be an existing Spark Session or Spark Context.") + SparkSession.__doc__ = PySparkSession.__doc__ @@ -408,16 +557,12 @@ def _test() -> None: import pyspark.sql.connect.session globs = pyspark.sql.connect.session.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.session tests", conf=SparkConf()) - globs["_spark"] = PySparkSession( - sc, options={"spark.app.name": "sql.connect.session tests"} + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.session tests") + .remote("local[4]") + .getOrCreate() ) - # Creates a remote Spark session. - os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() - # Uses PySpark session to test builder. globs["SparkSession"] = PySparkSession # Spark Connect does not support to set master together. 
@@ -438,7 +583,7 @@ def _test() -> None: ) globs["spark"].stop() - globs["_spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py index 315d74709fb..d504b19f800 100644 --- a/python/pyspark/sql/connect/window.py +++ b/python/pyspark/sql/connect/window.py @@ -18,7 +18,6 @@ import sys from typing import TYPE_CHECKING, Union, Sequence, List, Optional -from pyspark import SparkContext, SparkConf from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( ColumnReference, @@ -240,13 +239,11 @@ def _test() -> None: import pyspark.sql.connect.window globs = pyspark.sql.connect.window.__dict__.copy() - # Works around to create a regular Spark session - sc = SparkContext("local[4]", "sql.connect.window tests", conf=SparkConf()) - globs["_spark"] = PySparkSession(sc, options={"spark.app.name": "sql.connect.window tests"}) - - # Creates a remote Spark session. - os.environ["SPARK_REMOTE"] = "sc://localhost" - globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.window tests") + .remote("local[4]") + .getOrCreate() + ) (failure_count, test_count) = doctest.testmod( pyspark.sql.connect.window, @@ -257,7 +254,7 @@ def _test() -> None: ) globs["spark"].stop() - globs["_spark"].stop() + if failure_count: sys.exit(-1) else: diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1e4e6f5e3ad..2d6b3f88e1f 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -270,10 +270,13 @@ class SparkSession(SparkConversionMixin): % self._options.get("spark.master", os.environ.get("MASTER")) ) - if "SPARK_REMOTE" in os.environ and os.environ["SPARK_REMOTE"] != v: + if ("SPARK_REMOTE" in os.environ and os.environ["SPARK_REMOTE"] != v) and ( + "SPARK_LOCAL_REMOTE" in os.environ and not v.startswith("local") + ): raise RuntimeError( - "Only one Spark Connect client URL can be set; however, got a different" - "URL [%s] from the existing [%s]" % (os.environ["SPARK_REMOTE"], v) + "Only one Spark Connect client URL can be set; however, got a " + "different URL [%s] from the existing [%s]" + % (os.environ["SPARK_REMOTE"], v) ) with self._lock: @@ -415,31 +418,39 @@ class SparkSession(SparkConversionMixin): True """ from pyspark.context import SparkContext + from pyspark.conf import SparkConf opts = dict(self._options) - if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: - with SparkContext._lock: - from pyspark.sql.connect.session import SparkSession as RemoteSparkSession - - if ( - SparkContext._active_spark_context is None - and SparkSession._instantiatedSession is None - ): - url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE")) - os.environ["SPARK_REMOTE"] = url - opts["spark.remote"] = url - return RemoteSparkSession.builder.config(map=opts).getOrCreate() - elif "SPARK_TESTING" not in os.environ: - raise RuntimeError( - "Cannot start a remote Spark session because there " - "is a regular Spark Connect already running." - ) - - # Cannot reach here in production. Test-only. 
- return RemoteSparkSession.builder.config(map=opts).getOrCreate() with self._lock: - from pyspark.conf import SparkConf + if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: + with SparkContext._lock: + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + + if ( + SparkContext._active_spark_context is None + and SparkSession._instantiatedSession is None + ): + url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE")) + + if url.startswith("local"): + os.environ["SPARK_LOCAL_REMOTE"] = "1" + RemoteSparkSession._start_connect_server(url) + url = "sc://localhost" + + os.environ["SPARK_REMOTE"] = url + opts["spark.remote"] = url + return RemoteSparkSession.builder.config(map=opts).getOrCreate() + elif "SPARK_LOCAL_REMOTE" in os.environ: + url = "sc://localhost" + os.environ["SPARK_REMOTE"] = url + opts["spark.remote"] = url + return RemoteSparkSession.builder.config(map=opts).getOrCreate() + else: + raise RuntimeError( + "Cannot start a remote Spark session because there " + "is a regular Spark Connect already running." + ) session = SparkSession._instantiatedSession if session is None or session._sc._jsc is None: diff --git a/python/pyspark/sql/tests/connect/test_parity_catalog.py b/python/pyspark/sql/tests/connect/test_parity_catalog.py index e5135edb8cf..3da702198ac 100644 --- a/python/pyspark/sql/tests/connect/test_parity_catalog.py +++ b/python/pyspark/sql/tests/connect/test_parity_catalog.py @@ -16,31 +16,12 @@ # import unittest -import os -from pyspark.sql import SparkSession from pyspark.sql.tests.test_catalog import CatalogTestsMixin -from pyspark.testing.connectutils import should_test_connect, connect_requirement_message -from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.connectutils import ReusedConnectTestCase -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class CatalogParityTests(CatalogTestsMixin, ReusedSQLTestCase): - @classmethod - def setUpClass(cls): - super(CatalogParityTests, cls).setUpClass() - cls._spark = cls.spark # Assign existing Spark session to run the server - # Sets the remote address. Now, we create a remote Spark Session. - # Note that this is only allowed in testing. 
- os.environ["SPARK_REMOTE"] = "sc://localhost" - cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate() - - @classmethod - def tearDownClass(cls): - super(CatalogParityTests, cls).tearDownClass() - cls._spark.stop() - del os.environ["SPARK_REMOTE"] - +class CatalogParityTests(CatalogTestsMixin, ReusedConnectTestCase): # TODO(SPARK-41612): Support Catalog.isCached # TODO(SPARK-41600): Support Catalog.cacheTable # TODO(SPARK-41623): Support Catalog.uncacheTable diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index f5d6db38e8d..ed7c85d32ab 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -16,31 +16,12 @@ # import unittest -import os -from pyspark.sql import SparkSession from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin -from pyspark.testing.connectutils import should_test_connect, connect_requirement_message -from pyspark.testing.sqlutils import ReusedSQLTestCase - - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class DataFrameParityTests(DataFrameTestsMixin, ReusedSQLTestCase): - @classmethod - def setUpClass(cls): - super(DataFrameParityTests, cls).setUpClass() - cls._spark = cls.spark # Assign existing Spark session to run the server - # Sets the remote address. Now, we create a remote Spark Session. - # Note that this is only allowed in testing. - os.environ["SPARK_REMOTE"] = "sc://localhost" - cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate() - - @classmethod - def tearDownClass(cls): - super(DataFrameParityTests, cls).tearDownClass() - cls._spark.stop() - del os.environ["SPARK_REMOTE"] +from pyspark.testing.connectutils import ReusedConnectTestCase + +class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase): # TODO(SPARK-41612): support Catalog.isCached @unittest.skip("Fails in Spark Connect, should enable.") def test_cache(self): diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py index 3e46e1caa3e..89571aff6d6 100644 --- a/python/pyspark/sql/tests/connect/test_parity_functions.py +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -16,34 +16,12 @@ # import unittest -import os -from pyspark.sql import SparkSession from pyspark.sql.tests.test_functions import FunctionsTestsMixin -from pyspark.testing.connectutils import should_test_connect, connect_requirement_message -from pyspark.testing.sqlutils import ReusedSQLTestCase - - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class FunctionsParityTests(ReusedSQLTestCase, FunctionsTestsMixin): - @classmethod - def setUpClass(cls): - from pyspark.sql.connect.session import SparkSession as RemoteSparkSession - - super(FunctionsParityTests, cls).setUpClass() - cls._spark = cls.spark # Assign existing Spark session to run the server - # Sets the remote address. Now, we create a remote Spark Session. - # Note that this is only allowed in testing. 
- os.environ["SPARK_REMOTE"] = "sc://localhost" - cls.spark = SparkSession.builder.remote("sc://localhost").getOrCreate() - assert isinstance(cls.spark, RemoteSparkSession) - - @classmethod - def tearDownClass(cls): - super(FunctionsParityTests, cls).tearDownClass() - cls.spark = cls._spark.stop() - del os.environ["SPARK_REMOTE"] +from pyspark.testing.connectutils import ReusedConnectTestCase + +class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase): # TODO(SPARK-41897): Parity in Error types between pyspark and connect functions @unittest.skip("Fails in Spark Connect, should enable.") def test_assert_true(self): diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index bb87393b362..4d0d5541067 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -531,7 +531,7 @@ class DataFrameTestsMixin: # Type check self.assertRaises(TypeError, self.df.withColumns, ["key"]) - self.assertRaises(AssertionError, self.df.withColumns) + self.assertRaises(Exception, self.df.withColumns) def test_generic_hints(self): from pyspark.sql import DataFrame diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index bec116b5f79..0ebce47bd54 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -14,17 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import shutil +import tempfile import typing import os import functools import unittest +from pyspark import Row from pyspark.testing.sqlutils import ( have_pandas, have_pyarrow, pandas_requirement_message, pyarrow_requirement_message, + SQLTestUtils, ) +from pyspark.sql.session import SparkSession as PySparkSession grpc_requirement_message = None @@ -154,3 +159,23 @@ class PlanOnlyTestFixture(unittest.TestCase): cls.connect.drop_hook("range") cls.connect.drop_hook("sql") cls.connect.drop_hook("with_plan") + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils): + """ + Spark Connect version of :class:`pyspark.testing.sqlutils.ReusedSQLTestCase`. 
+ """ + + @classmethod + def setUpClass(cls): + cls.spark = PySparkSession.builder.appName(cls.__name__).remote("local[4]").getOrCreate() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.spark.createDataFrame(cls.testData) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + cls.spark.stop() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 70474f4d5c4..196377cce2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.{MutableURLClassLoader, Utils} private[sql] object PythonSQLUtils extends Logging { private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = { @@ -126,6 +127,19 @@ private[sql] object PythonSQLUtils extends Logging { deserializer(internalRow) } + /** + * Internal-only helper for Spark Connect's local mode. This is only used for + * local development, not for production. This method should not be used in + * production code path. + */ + def addJarToCurrentClassLoader(path: String): Unit = { + Utils.getContextOrSparkClassLoader match { + case cl: MutableURLClassLoader => cl.addURL(Utils.resolveURI(path).toURL) + case cl => logWarning( + s"Unsupported class loader $cl will not update jars in the thread class loader.") + } + } + def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr)) def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column = --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org