This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 5a71a2e4a3d [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client
5a71a2e4a3d is described below

commit 5a71a2e4a3d6ad5c6393b64fb76f571051ee3c94
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 8 04:15:07 2023 +0200

    [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client
    
    ### What changes were proposed in this pull request?
    This adds the `default` and `active` session variables to `SparkSession`:
    - The `default` session is a global value. It is typically the first session created through `getOrCreate`. It can be changed through `set` or `clear`. If a session is closed and it is the `default` session, we clear the `default` session.
    - The `active` session is a thread-local value. It is typically the first session created in the thread, or it inherits its value from its parent thread. It can be changed through `set` or `clear`; note that these methods operate thread-locally, so they won't affect parent or child threads. If a session is closed and it is the `active` session for the current thread, we clear the active value (only for the current thread!). A minimal usage sketch follows.
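    
    A minimal sketch of these semantics (the connection string is a placeholder; it assumes the Spark Connect Scala client on the classpath and a reachable endpoint):
    
    ```scala
    import org.apache.spark.sql.SparkSession
    
    // The first session created becomes both the default and the active session.
    val s1 = SparkSession.builder().remote("sc://localhost").getOrCreate()
    assert(SparkSession.getDefaultSession.contains(s1))
    assert(SparkSession.getActiveSession.contains(s1))
    
    // `active` is thread-local; overriding it here does not affect other threads.
    val s2 = s1.newSession()
    SparkSession.setActiveSession(s2)
    assert(SparkSession.active == s2)
    
    // Closing the default session clears the default slot; s2 stays active here.
    s1.close()
    assert(SparkSession.getDefaultSession.isEmpty)
    assert(SparkSession.getActiveSession.contains(s2))
    ```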
    
    ### Why are the changes needed?
    To increase compatibility with the existing SparkSession API in `sql/core`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds a couple of methods that were missing from the Scala Client.
    
    ### How was this patch tested?
    Added tests to `SparkSessionSuite`.
    
    Closes #42367 from hvanhovell/SPARK-43429.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 7493c5764f9644878babacccd4f688fe13ef84aa)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  | 100 ++++++++++++--
 .../org/apache/spark/sql/SparkSessionSuite.scala   | 144 +++++++++++++++++++--
 .../CheckConnectJvmClientCompatibility.scala       |   2 -
 3 files changed, 225 insertions(+), 21 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 355d7edadc7..7367ed153f7 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 import java.io.Closeable
 import java.net.URI
 import java.util.concurrent.TimeUnit._
-import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
 import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -730,6 +730,23 @@ object SparkSession extends Logging {
       override def load(c: Configuration): SparkSession = create(c)
     })
 
+  /** The active SparkSession for the current thread. */
+  private val activeThreadSession = new InheritableThreadLocal[SparkSession]
+
+  /** Reference to the root SparkSession. */
+  private val defaultSession = new AtomicReference[SparkSession]
+
+  /**
+   * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
+   * they are not set yet.
+   */
+  private def setDefaultAndActiveSession(session: SparkSession): Unit = {
+    defaultSession.compareAndSet(null, session)
+    if (getActiveSession.isEmpty) {
+      setActiveSession(session)
+    }
+  }
+
   /**
   * Create a new [[SparkSession]] based on the connect client [[Configuration]].
    */
@@ -742,8 +759,17 @@ object SparkSession extends Logging {
    */
   private[sql] def onSessionClose(session: SparkSession): Unit = {
     sessions.invalidate(session.client.configuration)
+    defaultSession.compareAndSet(session, null)
+    if (getActiveSession.contains(session)) {
+      clearActiveSession()
+    }
   }
 
+  /**
+   * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
+   *
+   * @since 3.4.0
+   */
   def builder(): Builder = new Builder()
 
   private[sql] lazy val cleaner = {
@@ -799,10 +825,15 @@ object SparkSession extends Logging {
      *
      * This will always return a newly created session.
      *
+     * This method will update the default and/or active session if they are not set.
+     *
      * @since 3.5.0
      */
     def create(): SparkSession = {
-      tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration))
+      val session = tryCreateSessionFromClient()
+        .getOrElse(SparkSession.this.create(builder.configuration))
+      setDefaultAndActiveSession(session)
+      session
     }
 
     /**
@@ -811,30 +842,79 @@ object SparkSession extends Logging {
      * If a session exists with the same configuration, it is returned instead of creating a new
      * session.
      *
+     * This method will update the default and/or active session if they are not set.
+     *
      * @since 3.5.0
      */
     def getOrCreate(): SparkSession = {
-      tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration))
+      val session = tryCreateSessionFromClient()
+        .getOrElse(sessions.get(builder.configuration))
+      setDefaultAndActiveSession(session)
+      session
     }
   }
 
-  def getActiveSession: Option[SparkSession] = {
-    throw new UnsupportedOperationException("getActiveSession is not 
supported")
+  /**
+   * Returns the default SparkSession.
+   *
+   * @since 3.5.0
+   */
+  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
+
+  /**
+   * Sets the default SparkSession.
+   *
+   * @since 3.5.0
+   */
+  def setDefaultSession(session: SparkSession): Unit = {
+    defaultSession.set(session)
   }
 
-  def getDefaultSession: Option[SparkSession] = {
-    throw new UnsupportedOperationException("getDefaultSession is not 
supported")
+  /**
+   * Clears the default SparkSession.
+   *
+   * @since 3.5.0
+   */
+  def clearDefaultSession(): Unit = {
+    defaultSession.set(null)
   }
 
+  /**
+   * Returns the active SparkSession for the current thread.
+   *
+   * @since 3.5.0
+   */
+  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get())
+
+  /**
+   * Changes the SparkSession that will be returned in this thread and its children when
+   * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
+   * an isolated SparkSession.
+   *
+   * @since 3.5.0
+   */
   def setActiveSession(session: SparkSession): Unit = {
-    throw new UnsupportedOperationException("setActiveSession is not 
supported")
+    activeThreadSession.set(session)
   }
 
+  /**
+   * Clears the active SparkSession for the current thread.
+   *
+   * @since 3.5.0
+   */
   def clearActiveSession(): Unit = {
-    throw new UnsupportedOperationException("clearActiveSession is not 
supported")
+    activeThreadSession.remove()
   }
 
+  /**
+   * Returns the currently active SparkSession, otherwise the default one. If there is no default
+   * SparkSession, throws an exception.
+   *
+   * @since 3.5.0
+   */
   def active: SparkSession = {
-    throw new UnsupportedOperationException("active is not supported")
+    getActiveSession
+      .orElse(getDefaultSession)
+      .getOrElse(throw new IllegalStateException("No active or default Spark 
session found"))
   }
 }
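
The `active` session above is backed by an `InheritableThreadLocal`: a thread spawned after its parent sets an active session starts with the parent's value, and later updates on either side stay independent. A self-contained sketch of that JDK behavior (illustrative object name; no Spark involved):

```scala
// Sketch of the InheritableThreadLocal semantics underlying the active session.
object InheritanceDemo extends App {
  val local = new InheritableThreadLocal[String]
  local.set("parent-value")

  val child = new Thread(() => {
    // The child starts with a copy of the parent's value at spawn time...
    assert(local.get() == "parent-value")
    // ...and updates made here are invisible to the parent.
    local.set("child-value")
  })
  child.start()
  child.join()

  assert(local.get() == "parent-value") // the parent's value is unchanged
}
```
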
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
index 97fb46bf48a..f06744399f8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -16,6 +16,10 @@
  */
 package org.apache.spark.sql
 
+import java.util.concurrent.{Executors, Phaser}
+
+import scala.util.control.NonFatal
+
 import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}
 
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
@@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite
  * Tests for non-dataframe related SparkSession operations.
  */
 class SparkSessionSuite extends ConnectFunSuite {
+  private val connectionString1: String = "sc://test.it:17845"
+  private val connectionString2: String = "sc://test.me:14099"
+  private val connectionString3: String = "sc://doit:16845"
+
   test("default") {
     val session = SparkSession.builder().getOrCreate()
     assert(session.client.configuration.host == "localhost")
@@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite {
   }
 
   test("remote") {
-    val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate()
+    val session = SparkSession.builder().remote(connectionString2).getOrCreate()
     assert(session.client.configuration.host == "test.me")
     assert(session.client.configuration.port == 14099)
     session.close()
   }
 
   test("getOrCreate") {
-    val connectionString = "sc://test.it:17865"
-    val session1 = SparkSession.builder().remote(connectionString).getOrCreate()
-    val session2 = SparkSession.builder().remote(connectionString).getOrCreate()
+    val session1 = SparkSession.builder().remote(connectionString1).getOrCreate()
+    val session2 = SparkSession.builder().remote(connectionString1).getOrCreate()
     try {
       assert(session1 eq session2)
     } finally {
@@ -51,9 +58,8 @@ class SparkSessionSuite extends ConnectFunSuite {
   }
 
   test("create") {
-    val connectionString = "sc://test.it:17845"
-    val session1 = SparkSession.builder().remote(connectionString).create()
-    val session2 = SparkSession.builder().remote(connectionString).create()
+    val session1 = SparkSession.builder().remote(connectionString1).create()
+    val session2 = SparkSession.builder().remote(connectionString1).create()
     try {
       assert(session1 ne session2)
       assert(session1.client.configuration == session2.client.configuration)
@@ -64,8 +70,7 @@ class SparkSessionSuite extends ConnectFunSuite {
   }
 
   test("newSession") {
-    val connectionString = "sc://doit:16845"
-    val session1 = SparkSession.builder().remote(connectionString).create()
+    val session1 = SparkSession.builder().remote(connectionString3).create()
     val session2 = session1.newSession()
     try {
       assert(session1 ne session2)
@@ -92,5 +97,126 @@ class SparkSessionSuite extends ConnectFunSuite {
     assertThrows[RuntimeException] {
       session.range(10).count()
     }
+    session.close()
+  }
+
+  test("Default/Active session") {
+    // Make sure we start with a clean slate.
+    SparkSession.clearDefaultSession()
+    SparkSession.clearActiveSession()
+    assert(SparkSession.getDefaultSession.isEmpty)
+    assert(SparkSession.getActiveSession.isEmpty)
+    intercept[IllegalStateException](SparkSession.active)
+
+    // Create a session
+    val session1 = SparkSession.builder().remote(connectionString1).getOrCreate()
+    assert(SparkSession.getDefaultSession.contains(session1))
+    assert(SparkSession.getActiveSession.contains(session1))
+    assert(SparkSession.active == session1)
+
+    // Create another session...
+    val session2 = SparkSession.builder().remote(connectionString2).create()
+    assert(SparkSession.getDefaultSession.contains(session1))
+    assert(SparkSession.getActiveSession.contains(session1))
+    SparkSession.setActiveSession(session2)
+    assert(SparkSession.getDefaultSession.contains(session1))
+    assert(SparkSession.getActiveSession.contains(session2))
+
+    // Clear sessions
+    SparkSession.clearDefaultSession()
+    assert(SparkSession.getDefaultSession.isEmpty)
+    SparkSession.clearActiveSession()
+    assert(SparkSession.getActiveSession.isEmpty)
+
+    // Flip sessions
+    SparkSession.setActiveSession(session1)
+    SparkSession.setDefaultSession(session2)
+    assert(SparkSession.getDefaultSession.contains(session2))
+    assert(SparkSession.getActiveSession.contains(session1))
+
+    // Close session1
+    session1.close()
+    assert(SparkSession.getDefaultSession.contains(session2))
+    assert(SparkSession.getActiveSession.isEmpty)
+
+    // Close session2
+    session2.close()
+    assert(SparkSession.getDefaultSession.isEmpty)
+    assert(SparkSession.getActiveSession.isEmpty)
+  }
+
+  test("active session in multiple threads") {
+    SparkSession.clearDefaultSession()
+    SparkSession.clearActiveSession()
+    val session1 = SparkSession.builder().remote(connectionString1).create()
+    val session2 = SparkSession.builder().remote(connectionString1).create()
+    SparkSession.setActiveSession(session2)
+    assert(SparkSession.getDefaultSession.contains(session1))
+    assert(SparkSession.getActiveSession.contains(session2))
+
+    val phaser = new Phaser(2)
+    val executor = Executors.newFixedThreadPool(2)
+    def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = {
+      executor.submit[Boolean] { () =>
+        try {
+          block(phaser)
+          true
+        } catch {
+          case NonFatal(e) =>
+            phaser.forceTermination()
+            throw e
+        }
+      }
+    }
+
+    try {
+      val script1 = execute { phaser =>
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.contains(session1))
+        assert(SparkSession.getActiveSession.contains(session2))
+
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.contains(session1))
+        assert(SparkSession.getActiveSession.contains(session2))
+        session1.close()
+
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.isEmpty)
+        assert(SparkSession.getActiveSession.contains(session2))
+        SparkSession.clearActiveSession()
+
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.isEmpty)
+        assert(SparkSession.getActiveSession.isEmpty)
+      }
+      val script2 = execute { phaser =>
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.contains(session1))
+        assert(SparkSession.getActiveSession.contains(session2))
+        SparkSession.clearActiveSession()
+      val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate()
+
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.contains(session1))
+        assert(SparkSession.getActiveSession.contains(internalSession))
+
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.isEmpty)
+        assert(SparkSession.getActiveSession.contains(internalSession))
+
+        phaser.arriveAndAwaitAdvance()
+        assert(SparkSession.getDefaultSession.isEmpty)
+        assert(SparkSession.getActiveSession.contains(internalSession))
+        internalSession.close()
+        assert(SparkSession.getActiveSession.isEmpty)
+      }
+      assert(script1.get())
+      assert(script2.get())
+      assert(SparkSession.getActiveSession.contains(session2))
+      session2.close()
+      assert(SparkSession.getActiveSession.isEmpty)
+    } finally {
+      executor.shutdown()
+    }
   }
 }
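
The multi-threaded test above keeps its two scripts in lock step with a `java.util.concurrent.Phaser`: each `arriveAndAwaitAdvance()` call is a barrier that both threads must reach before either proceeds to the next phase. A stripped-down sketch of the pattern (illustrative names; JDK only):

```scala
import java.util.concurrent.{Executors, Phaser}

// Two tasks advance through numbered phases together; neither can pass a
// barrier until the other has arrived as well.
object PhaserDemo extends App {
  val phaser = new Phaser(2) // two registered parties
  val executor = Executors.newFixedThreadPool(2)

  def task(name: String): Runnable = () => {
    phaser.arriveAndAwaitAdvance() // barrier at the end of phase 0
    println(s"$name completed phase 0")
    phaser.arriveAndAwaitAdvance() // barrier at the end of phase 1
    println(s"$name completed phase 1")
  }

  executor.submit(task("script1"))
  executor.submit(task("script2"))
  executor.shutdown()
}
```
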
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 6e577e0f212..2bf9c41fb2c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
 
       // SparkSession
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"),

