This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 5a71a2e4a3d  [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client

5a71a2e4a3d is described below

commit 5a71a2e4a3d6ad5c6393b64fb76f571051ee3c94
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 8 04:15:07 2023 +0200

    [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client

    ### What changes were proposed in this pull request?
    This adds the `default` and `active` session variables to `SparkSession`:
    - The `default` session is a global value. It is typically the first session created through `getOrCreate`. It can be changed through `set` or `clear`. If the session is closed and it is the `default` session, we clear the `default` session.
    - The `active` session is a thread-local value. It is typically the first session created in this thread, or it inherits its value from its parent thread. It can be changed through `set` or `clear`; note that these methods operate thread-locally, so they won't change the parent or child threads. If the session is closed and it is the `active` session for the current thread, then we clear the active value (only for the current thread!).

    ### Why are the changes needed?
    To increase compatibility with the existing SparkSession API in `sql/core`.

    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds a couple of methods that were missing from the Scala Client.

    ### How was this patch tested?
    Added tests to `SparkSessionSuite`.

    Closes #42367 from hvanhovell/SPARK-43429.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 7493c5764f9644878babacccd4f688fe13ef84aa)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala | 100 ++++++++++++--
 .../org/apache/spark/sql/SparkSessionSuite.scala | 144 +++++++++++++++++++--
 .../CheckConnectJvmClientCompatibility.scala | 2 -
 3 files changed, 225 insertions(+), 21 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 355d7edadc7..7367ed153f7 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -730,6 +730,23 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet. + */ + private def setDefaultAndActiveSession(session: SparkSession): Unit = { + defaultSession.compareAndSet(null, session) + if (getActiveSession.isEmpty) { + setActiveSession(session) + } + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/ @@ -742,8 +759,17 @@ object SparkSession extends Logging { */ private[sql] def onSessionClose(session: SparkSession): Unit = { sessions.invalidate(session.client.configuration) + defaultSession.compareAndSet(session, null) + if (getActiveSession.contains(session)) { + clearActiveSession() + } } + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 3.4.0 + */ def builder(): Builder = new Builder() private[sql] lazy val cleaner = { @@ -799,10 +825,15 @@ object SparkSession extends Logging { * * This will always return a newly created session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def create(): SparkSession = { - tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(SparkSession.this.create(builder.configuration)) + setDefaultAndActiveSession(session) + session } /** @@ -811,30 +842,79 @@ object SparkSession extends Logging { * If a session exist with the same configuration that is returned instead of creating a new * session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def getOrCreate(): SparkSession = { - tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(sessions.get(builder.configuration)) + setDefaultAndActiveSession(session) + session } } - def getActiveSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getActiveSession is not supported") + /** + * Returns the default SparkSession. + * + * @since 3.5.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + + /** + * Sets the default SparkSession. + * + * @since 3.5.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) } - def getDefaultSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getDefaultSession is not supported") + /** + * Clears the default SparkSession. + * + * @since 3.5.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) } + /** + * Returns the active SparkSession for the current thread. + * + * @since 3.5.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * an isolated SparkSession. + * + * @since 3.5.0 + */ def setActiveSession(session: SparkSession): Unit = { - throw new UnsupportedOperationException("setActiveSession is not supported") + activeThreadSession.set(session) } + /** + * Clears the active SparkSession for current thread. + * + * @since 3.5.0 + */ def clearActiveSession(): Unit = { - throw new UnsupportedOperationException("clearActiveSession is not supported") + activeThreadSession.remove() } + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. 
+ * + * @since 3.5.0 + */ def active: SparkSession = { - throw new UnsupportedOperationException("active is not supported") + getActiveSession + .orElse(getDefaultSession) + .getOrElse(throw new IllegalStateException("No active or default Spark session found")) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 97fb46bf48a..f06744399f8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.sql +import java.util.concurrent.{Executors, Phaser} + +import scala.util.control.NonFatal + import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.sql.connect.client.util.ConnectFunSuite @@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite * Tests for non-dataframe related SparkSession operations. */ class SparkSessionSuite extends ConnectFunSuite { + private val connectionString1: String = "sc://test.it:17845" + private val connectionString2: String = "sc://test.me:14099" + private val connectionString3: String = "sc://doit:16845" + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") @@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite { } test("remote") { - val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate() + val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) session.close() } test("getOrCreate") { - val connectionString = "sc://test.it:17865" - val session1 = SparkSession.builder().remote(connectionString).getOrCreate() - val session2 = SparkSession.builder().remote(connectionString).getOrCreate() + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + val session2 = SparkSession.builder().remote(connectionString1).getOrCreate() try { assert(session1 eq session2) } finally { @@ -51,9 +58,8 @@ class SparkSessionSuite extends ConnectFunSuite { } test("create") { - val connectionString = "sc://test.it:17845" - val session1 = SparkSession.builder().remote(connectionString).create() - val session2 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() try { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) @@ -64,8 +70,7 @@ class SparkSessionSuite extends ConnectFunSuite { } test("newSession") { - val connectionString = "sc://doit:16845" - val session1 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString3).create() val session2 = session1.newSession() try { assert(session1 ne session2) @@ -92,5 +97,126 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } + session.close() + } + + test("Default/Active session") { + // Make sure we start with a clean slate. 
+ SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + intercept[IllegalStateException](SparkSession.active) + + // Create a session + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + assert(SparkSession.active == session1) + + // Create another session... + val session2 = SparkSession.builder().remote(connectionString2).create() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Clear sessions + SparkSession.clearDefaultSession() + assert(SparkSession.getDefaultSession.isEmpty) + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + + // Flip sessions + SparkSession.setActiveSession(session1) + SparkSession.setDefaultSession(session2) + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.contains(session1)) + + // Close session1 + session1.close() + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.isEmpty) + + // Close session2 + session2.close() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + + test("active session in multiple threads") { + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + val phaser = new Phaser(2) + val executor = Executors.newFixedThreadPool(2) + def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = { + executor.submit[Boolean] { () => + try { + block(phaser) + true + } catch { + case NonFatal(e) => + phaser.forceTermination() + throw e + } + } + } + + try { + val script1 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + session1.close() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + val script2 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + 
assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + internalSession.close() + assert(SparkSession.getActiveSession.isEmpty) + } + assert(script1.get()) + assert(script2.get()) + assert(SparkSession.getActiveSession.contains(session2)) + session2.close() + assert(SparkSession.getActiveSession.isEmpty) + } finally { + executor.shutdown() + } } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6e577e0f212..2bf9c41fb2c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
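As a usage illustration of the default/active behaviour described in the commit message above, here is a minimal sketch of the new accessors added to the Scala Client. It is not part of the commit; the `sc://localhost:15002` endpoint is a placeholder, and it assumes no other session has been created in this JVM yet.

```scala
import org.apache.spark.sql.SparkSession

object DefaultActiveSketch {
  def main(args: Array[String]): Unit = {
    // The first session created through getOrCreate becomes both the global
    // default session and the thread-local active session.
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    assert(SparkSession.getDefaultSession.contains(spark))
    assert(SparkSession.getActiveSession.contains(spark))
    assert(SparkSession.active eq spark)

    // create() always returns a new session; because default/active are already
    // set, it leaves them untouched.
    val other = SparkSession.builder().remote("sc://localhost:15002").create()
    assert(SparkSession.getDefaultSession.contains(spark))

    // Both values can be replaced or cleared explicitly.
    SparkSession.setActiveSession(other)
    assert(SparkSession.active eq other)
    SparkSession.clearActiveSession()
    // With no active session, active falls back to the default one.
    assert(SparkSession.active eq spark)

    // Closing the default session clears the default slot; with neither an
    // active nor a default session left, SparkSession.active would now throw
    // an IllegalStateException.
    other.close()
    spark.close()
    assert(SparkSession.getDefaultSession.isEmpty)
  }
}
```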
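The thread-local side of `setActiveSession` (each thread, including child threads via inheritance, can hold its own active session) can be sketched the same way; again the endpoints are placeholders and this is not code from the commit.

```scala
import org.apache.spark.sql.SparkSession

object PerThreadSessionsSketch {
  def main(args: Array[String]): Unit = {
    val parent = SparkSession.builder().remote("sc://example-host:15002").getOrCreate()

    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        // The worker starts with the active session inherited from its parent thread...
        assert(SparkSession.getActiveSession.contains(parent))

        // ...but can swap in its own session without affecting the parent thread.
        val isolated = SparkSession.builder().remote("sc://other-host:15002").create()
        SparkSession.setActiveSession(isolated)
        assert(SparkSession.active eq isolated)

        // Closing the worker's active session only clears the value for this thread.
        isolated.close()
        assert(SparkSession.getActiveSession.isEmpty)
      }
    })
    worker.start()
    worker.join()

    // The parent thread's active session (and the global default) are untouched.
    assert(SparkSession.getActiveSession.contains(parent))
    assert(SparkSession.getDefaultSession.contains(parent))
    parent.close()
  }
}
```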