This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9368a0f0c10 [SPARK-44694][PYTHON][CONNECT] Refactor active sessions and expose them as an API 9368a0f0c10 is described below commit 9368a0f0c1001fb6fd64799a2e744874b6cd27e4 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Tue Aug 8 11:03:05 2023 +0900 [SPARK-44694][PYTHON][CONNECT] Refactor active sessions and expose them as an API ### What changes were proposed in this pull request? This PR proposes to (mostly) refactor all the internal workarounds to get the active session correctly. There are a few things to note: - _PySpark without Spark Connect does not already support the hierarchy of active sessions_. With pinned thread mode (enabled by default), PySpark does map each Python thread to a JVM thread, but the thread creation happens within the gateway server, which does not respect the thread hierarchy. Therefore, this PR follows exactly the same behaviour. - A new thread will not have an active session by default. - Other behaviours are the same as PySpark without Connect, see also https://github.com/apache/spark/pull/42367 - Since I am here, I piggyback a few documentation changes. We missed documenting `SparkSession.readStream`, `SparkSession.streams`, `SparkSession.udtf`, `SparkSession.conf` and `SparkSession.version` in Spark Connect. - The changes here are mostly refactoring that reuses existing unittests while I expose two methods: - `SparkSession.getActiveSession` (only for Spark Connect) - `SparkSession.active` (for both in PySpark) ### Why are the changes needed? For Spark Connect users to be able to play with active and default sessions in Python. ### Does this PR introduce _any_ user-facing change? Yes, it adds new APIs: - `SparkSession.getActiveSession` (only for Spark Connect) - `SparkSession.active` (for both in PySpark) ### How was this patch tested? Existing unittests should cover all. Closes #42371 from HyukjinKwon/SPARK-44694. 
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../source/reference/pyspark.sql/spark_session.rst | 1 + python/pyspark/errors/error_classes.py | 5 + python/pyspark/ml/connect/io_utils.py | 8 +- python/pyspark/ml/connect/tuning.py | 11 ++- python/pyspark/ml/torch/distributor.py | 3 +- python/pyspark/ml/util.py | 13 --- python/pyspark/pandas/utils.py | 7 +- python/pyspark/sql/connect/session.py | 107 ++++++++++++++------- python/pyspark/sql/connect/udf.py | 25 +++-- python/pyspark/sql/connect/udtf.py | 27 +++--- python/pyspark/sql/session.py | 65 +++++++++++-- .../sql/tests/connect/test_connect_basic.py | 4 +- python/pyspark/sql/utils.py | 18 ++++ 13 files changed, 197 insertions(+), 97 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst index c16ca4f162f..f25dbab5f6b 100644 --- a/python/docs/source/reference/pyspark.sql/spark_session.rst +++ b/python/docs/source/reference/pyspark.sql/spark_session.rst @@ -28,6 +28,7 @@ See also :class:`SparkSession`. .. autosummary:: :toctree: api/ + SparkSession.active SparkSession.builder.appName SparkSession.builder.config SparkSession.builder.enableHiveSupport diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index a534bc6deb4..24885e94d32 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -617,6 +617,11 @@ ERROR_CLASSES_JSON = """ "Argument `<arg_name>` should be a WindowSpec, got <arg_type>." ] }, + "NO_ACTIVE_OR_DEFAULT_SESSION" : { + "message" : [ + "No active or default Spark session found. Please create a new Spark session before running the code." + ] + }, "NO_ACTIVE_SESSION" : { "message" : [ "No active Spark session found. Please create a new Spark session before running the code." 
diff --git a/python/pyspark/ml/connect/io_utils.py b/python/pyspark/ml/connect/io_utils.py index 9a963086aaf..a09a244862c 100644 --- a/python/pyspark/ml/connect/io_utils.py +++ b/python/pyspark/ml/connect/io_utils.py @@ -23,7 +23,7 @@ import time from urllib.parse import urlparse from typing import Any, Dict, List from pyspark.ml.base import Params -from pyspark.ml.util import _get_active_session +from pyspark.sql import SparkSession from pyspark.sql.utils import is_remote @@ -34,7 +34,7 @@ _META_DATA_FILE_NAME = "metadata.json" def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None: - session = _get_active_session(is_remote()) + session = SparkSession.active() if is_remote(): session.copyFromLocalToFs(local_path, dest_path) else: @@ -228,7 +228,7 @@ class ParamsReadWrite(Params): .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() path_exist = True try: session.read.format("binaryFile").load(path).head() @@ -256,7 +256,7 @@ class ParamsReadWrite(Params): .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_") try: diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index 6d539933e1d..c22c31e84e8 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -178,11 +178,12 @@ def _parallelFitTasks( def get_single_task(index: int, param_map: Any) -> Callable[[], Tuple[int, float]]: def single_task() -> Tuple[int, float]: - # Active session is thread-local variable, in background thread the active session - # is not set, the following line sets it as the main thread active session. 
- active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] - active_session._jsparkSession # type: ignore[union-attr] - ) + if not is_remote(): + # Active session is thread-local variable, in background thread the active session + # is not set, the following line sets it as the main thread active session. + active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] + active_session._jsparkSession # type: ignore[union-attr] + ) model = estimator.fit(train, param_map) metric = evaluator.evaluate( diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py index 2056803d61c..a4e79b1dcc1 100644 --- a/python/pyspark/ml/torch/distributor.py +++ b/python/pyspark/ml/torch/distributor.py @@ -49,7 +49,6 @@ from pyspark.ml.torch.log_communication import ( # type: ignore LogStreamingServer, ) from pyspark.ml.dl_util import FunctionPickler -from pyspark.ml.util import _get_active_session def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]: @@ -165,7 +164,7 @@ class Distributor: from pyspark.sql.utils import is_remote self.is_remote = is_remote() - self.spark = _get_active_session(self.is_remote) + self.spark = SparkSession.active() # indicate whether the server side is local mode self.is_spark_local_master = False diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 2c90ff3cb7b..64676947017 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -747,16 +747,3 @@ def try_remote_functions(f: FuncT) -> FuncT: return f(*args, **kwargs) return cast(FuncT, wrapped) - - -def _get_active_session(is_remote: bool) -> SparkSession: - if not is_remote: - spark = SparkSession.getActiveSession() - else: - import pyspark.sql.connect.session - - spark = pyspark.sql.connect.session._active_spark_session # type: ignore[assignment] - - if spark is None: - raise RuntimeError("An active SparkSession is required for the distributor.") - return spark diff --git 
a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index c66b3359e77..55b9a57ef61 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -478,12 +478,7 @@ def is_testing() -> bool: def default_session() -> SparkSession: - if not is_remote(): - spark = SparkSession.getActiveSession() - else: - from pyspark.sql.connect.session import _active_spark_session - - spark = _active_spark_session # type: ignore[assignment] + spark = SparkSession.getActiveSession() if spark is None: spark = SparkSession.builder.appName("pandas-on-Spark").getOrCreate() diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 9bba0db05e4..d75a30c561f 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -18,6 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) +import threading import os import warnings from collections.abc import Sized @@ -36,6 +37,7 @@ from typing import ( overload, Iterable, TYPE_CHECKING, + ClassVar, ) import numpy as np @@ -93,14 +95,13 @@ if TYPE_CHECKING: from pyspark.sql.connect.udtf import UDTFRegistration -# `_active_spark_session` stores the active spark connect session created by -# `SparkSession.builder.getOrCreate`. It is used by ML code. 
-# If sessions are created with `SparkSession.builder.create`, it stores -# The last created session -_active_spark_session = None - - class SparkSession: + # The active SparkSession for the current thread + _active_session: ClassVar[threading.local] = threading.local() + # Reference to the root SparkSession + _default_session: ClassVar[Optional["SparkSession"]] = None + _lock: ClassVar[RLock] = RLock() + class Builder: """Builder for :class:`SparkSession`.""" @@ -176,8 +177,6 @@ class SparkSession: ) def create(self) -> "SparkSession": - global _active_spark_session - has_channel_builder = self._channel_builder is not None has_spark_remote = "spark.remote" in self._options @@ -200,23 +199,26 @@ class SparkSession: assert spark_remote is not None session = SparkSession(connection=spark_remote) - _active_spark_session = session + SparkSession._set_default_and_active_session(session) return session def getOrCreate(self) -> "SparkSession": - global _active_spark_session - if _active_spark_session is not None: - return _active_spark_session - _active_spark_session = self.create() - return _active_spark_session + with SparkSession._lock: + session = SparkSession.getActiveSession() + if session is None: + session = SparkSession._default_session + if session is None: + session = self.create() + return session _client: SparkConnectClient @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" return cls.Builder() + builder.__doc__ = PySparkSession.builder.__doc__ + def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] = None): """ Creates a new SparkSession for the Spark Connect interface. 
@@ -236,6 +238,38 @@ class SparkSession: self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + @classmethod + def _set_default_and_active_session(cls, session: "SparkSession") -> None: + """ + Set the (global) default :class:`SparkSession`, and (thread-local) + active :class:`SparkSession` when they are not set yet. + """ + with cls._lock: + if cls._default_session is None: + cls._default_session = session + if getattr(cls._active_session, "session", None) is None: + cls._active_session.session = session + + @classmethod + def getActiveSession(cls) -> Optional["SparkSession"]: + return getattr(cls._active_session, "session", None) + + getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__ + + @classmethod + def active(cls) -> "SparkSession": + session = cls.getActiveSession() + if session is None: + session = cls._default_session + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + + active.__doc__ = PySparkSession.active.__doc__ + def table(self, tableName: str) -> DataFrame: return self.read.table(tableName) @@ -251,6 +285,8 @@ class SparkSession: def readStream(self) -> "DataStreamReader": return DataStreamReader(self) + readStream.__doc__ = PySparkSession.readStream.__doc__ + def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None ) -> StructType: @@ -601,19 +637,20 @@ class SparkSession: # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. 
- global _active_spark_session - self.client.close() - _active_spark_session = None - - if "SPARK_LOCAL_REMOTE" in os.environ: - # When local mode is in use, follow the regular Spark session's - # behavior by terminating the Spark Connect server, - # meaning that you can stop local mode, and restart the Spark Connect - # client with a different remote address. - active_session = PySparkSession.getActiveSession() - if active_session is not None: - active_session.stop() - with SparkContext._lock: + with SparkSession._lock: + self.client.close() + if self is SparkSession._default_session: + SparkSession._default_session = None + if self is getattr(SparkSession._active_session, "session", None): + SparkSession._active_session.session = None + + if "SPARK_LOCAL_REMOTE" in os.environ: + # When local mode is in use, follow the regular Spark session's + # behavior by terminating the Spark Connect server, + # meaning that you can stop local mode, and restart the Spark Connect + # client with a different remote address. 
+ if PySparkSession._activeSession is not None: + PySparkSession._activeSession.stop() del os.environ["SPARK_LOCAL_REMOTE"] del os.environ["SPARK_CONNECT_MODE_ENABLED"] if "SPARK_REMOTE" in os.environ: @@ -628,20 +665,18 @@ class SparkSession: """ return self.client.is_closed - @classmethod - def getActiveSession(cls) -> Any: - raise PySparkNotImplementedError( - error_class="NOT_IMPLEMENTED", message_parameters={"feature": "getActiveSession()"} - ) - @property def conf(self) -> RuntimeConf: return RuntimeConf(self.client) + conf.__doc__ = PySparkSession.conf.__doc__ + @property def streams(self) -> "StreamingQueryManager": return StreamingQueryManager(self) + streams.__doc__ = PySparkSession.streams.__doc__ + def __getattr__(self, name: str) -> Any: if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession"]: raise PySparkAttributeError( @@ -675,6 +710,8 @@ class SparkSession: assert result is not None return result + version.__doc__ = PySparkSession.version.__doc__ + @property def client(self) -> "SparkConnectClient": return self._client diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 2d7e423d3d5..eb0541b9369 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -37,8 +37,7 @@ from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.types import DataType, StringType from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration -from pyspark.errors import PySparkTypeError - +from pyspark.errors import PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -58,14 +57,20 @@ def _create_py_udf( from pyspark.sql.udf import _create_arrow_py_udf if useArrow is None: - from pyspark.sql.connect.session import _active_spark_session - - is_arrow_enabled = ( - False - if _active_spark_session is None - else _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") - == 
"true" - ) + is_arrow_enabled = False + try: + from pyspark.sql.connect.session import SparkSession + + session = SparkSession.active() + is_arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. + else: + raise e else: is_arrow_enabled = useArrow diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 5a95075a655..c8495626292 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -68,13 +68,20 @@ def _create_py_udtf( if useArrow is not None: arrow_enabled = useArrow else: - from pyspark.sql.connect.session import _active_spark_session + from pyspark.sql.connect.session import SparkSession arrow_enabled = False - if _active_spark_session is not None: - value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") - if isinstance(value, str) and value.lower() == "true": - arrow_enabled = True + try: + session = SparkSession.active() + arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. + else: + raise e # Create a regular Python UDTF and check for invalid handler class. 
regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) @@ -160,17 +167,13 @@ class UserDefinedTableFunction: ) def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + from pyspark.sql.connect.session import SparkSession from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.connect.session import _active_spark_session - if _active_spark_session is None: - raise PySparkRuntimeError( - "An active SparkSession is required for " - "executing a Python user-defined table function." - ) + session = SparkSession.active() plan = self._build_common_inline_user_defined_table_function(*cols) - return DataFrame.withPlan(plan, _active_spark_session) + return DataFrame.withPlan(plan, session) def asNondeterministic(self) -> "UserDefinedTableFunction": self.deterministic = False diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ede6318782e..9141051fdf8 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -64,8 +64,8 @@ from pyspark.sql.types import ( _from_numpy_type, ) from pyspark.errors.exceptions.captured import install_exception_handler -from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str -from pyspark.errors import PySparkValueError, PySparkTypeError +from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod +from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType @@ -500,7 +500,7 @@ class SparkSession(SparkConversionMixin): ).applyModifiableSettings(session._jsparkSession, self._options) return session - # SparkConnect-specific API + # Spark Connect-specific API def create(self) -> "SparkSession": """Creates a new SparkSession. Can only be used in the context of Spark Connect and will throw an exception otherwise. 
@@ -510,6 +510,10 @@ class SparkSession(SparkConversionMixin): Returns ------- :class:`SparkSession` + + Notes + ----- + This method will update the default and/or active session if they are not set. """ opts = dict(self._options) if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: @@ -546,7 +550,11 @@ class SparkSession(SparkConversionMixin): # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" + """Creates a :class:`Builder` for constructing a :class:`SparkSession`. + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + """ return cls.Builder() _instantiatedSession: ClassVar[Optional["SparkSession"]] = None @@ -632,12 +640,16 @@ class SparkSession(SparkConversionMixin): return self.__class__(self._sc, self._jsparkSession.newSession()) @classmethod + @try_remote_session_classmethod def getActiveSession(cls) -> Optional["SparkSession"]: """ Returns the active :class:`SparkSession` for the current thread, returned by the builder .. versionadded:: 3.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`SparkSession` @@ -667,6 +679,30 @@ class SparkSession(SparkConversionMixin): else: return None + @classmethod + @try_remote_session_classmethod + def active(cls) -> "SparkSession": + """ + Returns the active or default :class:`SparkSession` for the current thread, returned by + the builder. + + .. versionadded:: 3.5.0 + + Returns + ------- + :class:`SparkSession` + Spark session if an active or default session exists for the current thread. + """ + session = cls.getActiveSession() + if session is None: + session = cls._instantiatedSession + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + @property def sparkContext(self) -> SparkContext: """ @@ -698,6 +734,9 @@ class SparkSession(SparkConversionMixin): .. 
versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- str @@ -719,6 +758,9 @@ class SparkSession(SparkConversionMixin): .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- :class:`pyspark.sql.conf.RuntimeConfig` @@ -726,7 +768,7 @@ class SparkSession(SparkConversionMixin): Examples -------- >>> spark.conf - <pyspark.sql.conf.RuntimeConfig object ...> + <pyspark...RuntimeConf...> Set a runtime configuration for the session @@ -805,6 +847,9 @@ class SparkSession(SparkConversionMixin): .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`UDTFRegistration` @@ -1639,6 +1684,9 @@ class SparkSession(SparkConversionMixin): .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -1650,7 +1698,7 @@ class SparkSession(SparkConversionMixin): Examples -------- >>> spark.readStream - <pyspark.sql.streaming.readwriter.DataStreamReader object ...> + <pyspark...DataStreamReader object ...> The example below uses Rate source that generates rows continuously. After that, we operate a modulo by 3, and then write the stream out to the console. @@ -1672,6 +1720,9 @@ class SparkSession(SparkConversionMixin): .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. 
@@ -1683,7 +1734,7 @@ class SparkSession(SparkConversionMixin): Examples -------- >>> spark.streams - <pyspark.sql.streaming.query.StreamingQueryManager object ...> + <pyspark...StreamingQueryManager object ...> Get the list of active streaming queries diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 065f1585a9f..0687fc9f313 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3043,9 +3043,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): def test_unsupported_session_functions(self): # SPARK-41934: Disable unsupported functions. - with self.assertRaises(NotImplementedError): - RemoteSparkSession.getActiveSession() - with self.assertRaises(NotImplementedError): RemoteSparkSession.builder.enableHiveSupport() @@ -3331,6 +3328,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase): spark.stop() def test_can_create_multiple_sessions_to_different_remotes(self): + self.spark.stop() self.assertIsNotNone(self.spark._client) # Creates a new remote session. other = PySparkSession.builder.remote("sc://other.remote:114/").create() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8b520ed653f..d4f56fe822f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import inspect import functools import os from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type @@ -258,6 +259,23 @@ def try_remote_observation(f: FuncT) -> FuncT: return cast(FuncT, wrapped) +def try_remote_session_classmethod(f: FuncT) -> FuncT: + """Mark API supported from Spark Connect.""" + + @functools.wraps(f) + def wrapped(*args: Any, **kwargs: Any) -> Any: + + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + from pyspark.sql.connect.session import SparkSession # type: ignore[misc] + + assert inspect.isclass(args[0]) + return getattr(SparkSession, f.__name__)(*args[1:], **kwargs) + else: + return f(*args, **kwargs) + + return cast(FuncT, wrapped) + + def pyspark_column_op( func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None ) -> Union["SeriesOrIndex", None]: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org