This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 7af4e358f3f [SPARK-44740][CONNECT] Support specifying `session_id` in SPARK_REMOTE connection string
7af4e358f3f is described below

commit 7af4e358f3f4902cc9601e56c2662b8921a925d6
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Aug 10 08:41:13 2023 +0900

    [SPARK-44740][CONNECT] Support specifying `session_id` in SPARK_REMOTE connection string

    ### What changes were proposed in this pull request?

    To support cross-language session sharing in Spark Connect, we need to be able to inject
    the session ID into the connection string, because on the server side the client-provided
    session ID is already used together with the user ID.

    ```
    SparkSession.builder.remote("sc://localhost/;session_id=abcdefg").getOrCreate()
    ```

    ### Why are the changes needed?

    Ease of use.

    ### Does this PR introduce _any_ user-facing change?

    Adds a way to configure the Spark Connect connection string with `session_id`.

    ### How was this patch tested?

    Added a UT for the parameter.

    Closes #42415 from grundprinzip/SPARK-44740.

    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/connect/client/SparkConnectClient.scala    | 22 ++++++++++++++--
 .../connect/client/SparkConnectClientParser.scala  |  3 +++
 .../SparkConnectClientBuilderParseTestSuite.scala  |  4 +++
 .../connect/client/SparkConnectClientSuite.scala   |  6 +++++
 connector/connect/docs/client-connection-string.md | 11 ++++++++
 python/pyspark/sql/connect/client/core.py          | 30 +++++++++++++++++++---
 .../sql/tests/connect/client/test_client.py        |  7 +++++
 .../sql/tests/connect/test_connect_basic.py        | 18 ++++++++++++-
 8 files changed, 94 insertions(+), 7 deletions(-)
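For illustration (not part of the patch): a minimal sketch of the cross-language sharing described above. The address `localhost:15002` is an assumption, and both ends use PySpark here purely for brevity; one of the two clients could equally be the Scala client. Because the server keys its session cache on the user ID plus the session ID, two clients presenting the same pair attach to the same server-side session.

```
import uuid

from pyspark.sql import SparkSession

shared_id = str(uuid.uuid4())  # must be a well-formed UUID string

# First client creates the server-side session keyed by user id + session id.
spark_a = SparkSession.builder.remote(
    f"sc://localhost:15002/;session_id={shared_id}"
).getOrCreate()

# A second client presenting the same session_id reuses that session.
spark_b = SparkSession.builder.remote(
    f"sc://localhost:15002/;session_id={shared_id}"
).getOrCreate()
```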
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index a028df536cf..637499f090c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -56,7 +56,7 @@ private[sql] class SparkConnectClient(
   // Generate a unique session ID for this client. This UUID must be unique to allow
   // concurrent Spark sessions of the same user. If the channel is closed, creating
   // a new client will create a new session ID.
-  private[sql] val sessionId: String = UUID.randomUUID.toString
+  private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)

   private[client] val artifactManager: ArtifactManager = {
     new ArtifactManager(configuration, sessionId, bstub, stub)
@@ -432,6 +432,7 @@ object SparkConnectClient {
     val PARAM_USE_SSL = "use_ssl"
     val PARAM_TOKEN = "token"
     val PARAM_USER_AGENT = "user_agent"
+    val PARAM_SESSION_ID = "session_id"
   }

   private def verifyURI(uri: URI): Unit = {
@@ -463,6 +464,21 @@ object SparkConnectClient {
       this
     }

+    def sessionId(value: String): Builder = {
+      try {
+        UUID.fromString(value).toString
+      } catch {
+        case e: IllegalArgumentException =>
+          throw new IllegalArgumentException(
+            "Parameter value 'session_id' must be a valid UUID format.",
+            e)
+      }
+      _configuration = _configuration.copy(sessionId = Some(value))
+      this
+    }
+
+    def sessionId: Option[String] = _configuration.sessionId
+
     def userAgent: String = _configuration.userAgent

     def option(key: String, value: String): Builder = {
@@ -490,6 +506,7 @@ object SparkConnectClient {
           case URIParams.PARAM_TOKEN => token(value)
           case URIParams.PARAM_USE_SSL =>
             if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
+          case URIParams.PARAM_SESSION_ID => sessionId(value)
           case _ => option(key, value)
         }
       }
@@ -576,7 +593,8 @@ object SparkConnectClient {
       userAgent: String = DEFAULT_USER_AGENT,
       retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy(),
       useReattachableExecute: Boolean = true,
-      interceptors: List[ClientInterceptor] = List.empty) {
+      interceptors: List[ClientInterceptor] = List.empty,
+      sessionId: Option[String] = None) {

     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
index dda769dc2ad..f873e1045bf 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
@@ -71,6 +71,9 @@ private[sql] object SparkConnectClientParser {
       case "--user_agent" :: tail =>
         val (value, remainder) = extract("--user_agent", tail)
         parse(remainder, builder.userAgent(value))
+      case "--session_id" :: tail =>
+        val (value, remainder) = extract("--session_id", tail)
+        parse(remainder, builder.sessionId(value))
      case "--option" :: tail =>
        if (args.isEmpty) {
          throw new IllegalArgumentException("--option requires a key-value pair")
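A note on the builder method above: `UUID.fromString(value).toString` is called only for its side effect of throwing on malformed input; the raw string is what gets stored. The Python client applies the same rule (see `core.py` below). A small sketch of the behavior from the Python side, where the host is a placeholder and no server is needed, since the check happens when the property is read:

```
from pyspark.sql.connect.client import ChannelBuilder

# A well-formed UUID passes the check and is exposed on the builder.
ok = ChannelBuilder("sc://localhost/;session_id=550e8400-e29b-41d4-a716-446655440000")
assert ok.session_id == "550e8400-e29b-41d4-a716-446655440000"

# Anything else fails fast on the client, mirroring the Scala builder's
# IllegalArgumentException.
try:
    ChannelBuilder("sc://localhost/;session_id=not-a-uuid").session_id
except ValueError as e:
    print(e)  # Parameter value 'session_id' must be a valid UUID format.
```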
"U1238", _.userId.get) argumentTest("user_name", "alice", _.userName.get) argumentTest("user_agent", "MY APP", _.userAgent) + argumentTest("session_id", UUID.randomUUID().toString, _.sessionId.get) test("Argument - remote") { val builder = @@ -55,6 +58,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite { assert(builder.token.contains("nahnah")) assert(builder.userId.contains("x127")) assert(builder.options === Map(("user_name", "Q"), ("param1", "x"))) + assert(builder.sessionId.isEmpty) } test("Argument - use_ssl") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 3436037809d..e483e0a7291 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -164,6 +164,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { client => { assert(client.configuration.host == "localhost") assert(client.configuration.port == 1234) + assert(client.sessionId != null) + // Must be able to parse the UUID + assert(UUID.fromString(client.sessionId) != null) }), TestPackURI( "sc://localhost/;", @@ -193,6 +196,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true), TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true), TestPackURI("sc://host:123/;token=", isCorrect = false), + TestPackURI("sc://host:123/;session_id=", isCorrect = false), + TestPackURI("sc://host:123/;session_id=abcdefgh", isCorrect = false), + TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true), TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true), TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true), TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false), diff --git a/connector/connect/docs/client-connection-string.md b/connector/connect/docs/client-connection-string.md index 6e5b0c80db7..ebab7cbff4f 100644 --- a/connector/connect/docs/client-connection-string.md +++ b/connector/connect/docs/client-connection-string.md @@ -91,6 +91,17 @@ sc://hostname:port/;param1=value;param2=value <i>Default: </i><pre>_SPARK_CONNECT_PYTHON</pre> in the Python client</td> <td><pre>user_agent=my_data_query_app</pre></td> </tr> + <tr> + <td>session_id</td> + <td>String</td> + <td>In addition to the user ID, the cache of Spark Sessions in the Spark Connect + server uses a session ID as the cache key. This option in the connection string + allows to provide this session ID to allow sharing Spark Sessions for the same users + for example across multiple languages. 
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index a7c3a92d3b1..5e6aacf5999 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -156,6 +156,7 @@ class ChannelBuilder:
     PARAM_TOKEN = "token"
     PARAM_USER_ID = "user_id"
     PARAM_USER_AGENT = "user_agent"
+    PARAM_SESSION_ID = "session_id"
     MAX_MESSAGE_LENGTH = 128 * 1024 * 1024

     @staticmethod
@@ -354,6 +355,22 @@ class ChannelBuilder:
         """
         return self.params[key]

+    @property
+    def session_id(self) -> Optional[str]:
+        """
+        Returns
+        -------
+        The session_id extracted from the parameters of the connection string or `None` if not
+        specified.
+        """
+        session_id = self.params.get(ChannelBuilder.PARAM_SESSION_ID, None)
+        if session_id is not None:
+            try:
+                uuid.UUID(session_id, version=4)
+            except ValueError as ve:
+                raise ValueError("Parameter value 'session_id' must be a valid UUID format.", ve)
+        return session_id
+
     def toChannel(self) -> grpc.Channel:
         """
         Applies the parameters of the connection string and creates a new
@@ -628,10 +645,15 @@ class SparkConnectClient(object):
         if retry_policy:
             self._retry_policy.update(retry_policy)

-        # Generate a unique session ID for this client. This UUID must be unique to allow
-        # concurrent Spark sessions of the same user. If the channel is closed, creating
-        # a new client will create a new session ID.
-        self._session_id = str(uuid.uuid4())
+        if self._builder.session_id is None:
+            # Generate a unique session ID for this client. This UUID must be unique to allow
+            # concurrent Spark sessions of the same user. If the channel is closed, creating
+            # a new client will create a new session ID.
+            self._session_id = str(uuid.uuid4())
+        else:
+            # Use the pre-defined session ID.
+            self._session_id = str(self._builder.session_id)
+
         if self._builder.userId is not None:
             self._user_id = self._builder.userId
         elif user_id is not None:
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 9276b88e153..9782add92f4 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -16,6 +16,7 @@
 #
 import unittest
+import uuid
 from typing import Optional

 from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
@@ -88,6 +89,12 @@ class SparkConnectClientTestCase(unittest.TestCase):
         client.close()
         self.assertTrue(client.is_closed)

+    def test_channel_builder_with_session(self):
+        dummy = str(uuid.uuid4())
+        chan = ChannelBuilder(f"sc://foo/;session_id={dummy}")
+        client = SparkConnectClient(chan)
+        self.assertEqual(client._session_id, chan.session_id)
+
 class MockService:
     # Simplest mock of the SparkConnectService.
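The branch added to `SparkConnectClient.__init__` above is exactly what the new unit test pins down: a pre-defined session id is adopted verbatim, and otherwise each client draws a fresh `uuid4`. A condensed sketch of that contract; the host `foo` is a dummy taken from the test, and constructing the client does not contact a server:

```
import uuid

from pyspark.sql.connect.client import ChannelBuilder, SparkConnectClient

fixed = str(uuid.uuid4())

# With session_id in the connection string, the client uses it verbatim.
client = SparkConnectClient(ChannelBuilder(f"sc://foo/;session_id={fixed}"))
assert client._session_id == fixed

# Without it, the client falls back to generating its own uuid4.
default_client = SparkConnectClient(ChannelBuilder("sc://foo/"))
assert default_client._session_id != fixed
```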
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0687fc9f313..63b65ecce1a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -23,6 +23,7 @@ import random
 import shutil
 import string
 import tempfile
+import uuid
 from collections import defaultdict

 from pyspark.errors import (
@@ -76,7 +77,7 @@ if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
     from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF
-    from pyspark.sql.connect.client.core import Retrying
+    from pyspark.sql.connect.client.core import Retrying, SparkConnectClient

 class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
@@ -3522,6 +3523,21 @@ class ChannelBuilderTests(unittest.TestCase):
         md = chan.metadata()
         self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)

+    def test_metadata(self):
+        id = str(uuid.uuid4())
+        chan = ChannelBuilder(f"sc://host/;session_id={id}")
+        self.assertEqual(id, chan.session_id)
+
+        with self.assertRaises(ValueError) as ve:
+            chan = ChannelBuilder("sc://host/;session_id=abcd")
+            SparkConnectClient(chan)
+        self.assertIn(
+            "Parameter value 'session_id' must be a valid UUID format.", str(ve.exception)
+        )
+
+        chan = ChannelBuilder("sc://host/")
+        self.assertIsNone(chan.session_id)
+
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org