This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a122b8acc2f4 [SPARK-46075][CONNECT] Improvements to SparkConnectSessionManager
a122b8acc2f4 is described below

commit a122b8acc2f47c58e8891a5f1464a588f77750e7
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Tue Dec 12 09:40:56 2023 -0800

    [SPARK-46075][CONNECT] Improvements to SparkConnectSessionManager
    
    ### What changes were proposed in this pull request?
    
    This is factored out from https://github.com/apache/spark/pull/43913 and is a continuation of https://github.com/apache/spark/pull/43546, where SparkConnectSessionManager was introduced.
    
    We want to stop using a Guava cache as the session cache and instead have our own custom logic with more control. This refactors the session manager and adds more tests.
    
    We introduce a mechanism that mirrors SparkConnectExecutionManager instead.
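    
    Like SparkConnectExecutionManager, the new manager keeps a plain map guarded by a lock, plus a single-threaded scheduled executor that periodically scans for expired entries. Below is a condensed, self-contained sketch of that shared pattern, not the actual code from this PR; names like `MaintenanceLoopSketch`, `Entry` and `store` are simplified stand-ins for the real identifiers in the diff below:
    
    ```scala
    import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
    
    import scala.collection.mutable
    import scala.util.control.NonFatal
    
    object MaintenanceLoopSketch {
      final case class Entry(lastAccessTimeMs: Long)
    
      private val store = mutable.HashMap[String, Entry]()
      private val lock = new Object
      private var scheduledExecutor: Option[ScheduledExecutorService] = None
    
      // Start the maintenance thread lazily, and only once.
      def schedulePeriodicChecks(intervalMs: Long, timeoutMs: Long): Unit = lock.synchronized {
        if (scheduledExecutor.isEmpty) {
          val executor = Executors.newSingleThreadScheduledExecutor()
          executor.scheduleAtFixedRate(
            () =>
              try periodicMaintenance(timeoutMs)
              catch { case NonFatal(ex) => ex.printStackTrace() },
            intervalMs,
            intervalMs,
            TimeUnit.MILLISECONDS)
          scheduledExecutor = Some(executor)
        }
      }
    
      // Collect expired entries under the lock; finish cleanup outside of it.
      def periodicMaintenance(timeoutMs: Long): Unit = {
        val nowMs = System.currentTimeMillis()
        val expired = lock.synchronized {
          val toRemove = store.filter { case (_, e) => e.lastAccessTimeMs + timeoutMs <= nowMs }
          toRemove.keys.foreach(store.remove)
          toRemove
        }
        expired.keys.foreach(key => println(s"Removed expired entry $key"))
      }
    }
    ```
    
    This gives the manager full control over when, and on which thread, removal happens, in contrast to the Guava behavior described below.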
    
    ### Why are the changes needed?
    
    With a Guava cache, only a single "inactivity timeout" can be specified for the whole cache; it cannot, for example, be overridden per session. The actual invalidation also does not happen in its own thread inside Guava, but is lazily piggybacked onto other operations on the cache in a work-stealing manner, making it opaque when session removal will actually happen.
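    
    Concretely, the single cache-wide timeout being removed looked like the first snippet below, while the second snippet mirrors the per-session `shouldExpire` check this PR adds (simplified sketch: the test-only `ignoreCustomTimeout` flag is omitted, and `defaultInactiveTimeoutMs` is assumed to be in scope):
    
    ```scala
    // Before: one expireAfterAccess timeout for the entire cache, enforced
    // lazily by whichever thread happens to touch the cache next.
    CacheBuilder
      .newBuilder()
      .expireAfterAccess(defaultInactiveTimeoutMs, TimeUnit.MILLISECONDS)
      .build[SessionKey, SessionHolder]()
    
    // After: each session may carry its own timeout, and the dedicated
    // maintenance thread decides explicitly which sessions expire.
    def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = {
      val timeoutMs = info.customInactiveTimeoutMs.getOrElse(defaultInactiveTimeoutMs)
      // A timeout of -1 means the session is kept forever.
      timeoutMs != -1 && info.lastAccessTimeMs + timeoutMs <= nowMs
    }
    ```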
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    SparkConnectSessionManagerSuite added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    GitHub Copilot assisted with some boilerplate auto-completion.
    
    Generated-by: Github Copilot
    
    Closes #43985 from juliuszsompolski/SPARK-46075.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../apache/spark/sql/connect/config/Connect.scala  |  11 +-
 .../spark/sql/connect/service/ExecuteHolder.scala  |  39 ++--
 .../spark/sql/connect/service/SessionHolder.scala  |  80 ++++++--
 .../service/SparkConnectExecutionManager.scala     |  48 +++--
 .../service/SparkConnectSessionManager.scala       | 224 ++++++++++++++++-----
 .../spark/sql/connect/SparkConnectServerTest.scala |   2 +-
 .../service/SparkConnectSessionManagerSuite.scala  | 137 +++++++++++++
 7 files changed, 429 insertions(+), 112 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index f7aa98af2fa3..ab4f06d508a0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -78,7 +78,8 @@ object Connect {
   val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT =
     buildStaticConf("spark.connect.session.manager.defaultSessionTimeout")
       .internal()
-      .doc("Timeout after which sessions without any new incoming RPC will be 
removed.")
+      .doc("Timeout after which sessions without any new incoming RPC will be 
removed. " +
+        "Setting it to -1 indicates that sessions should be kept forever.")
       .version("4.0.0")
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("60m")
@@ -93,6 +94,14 @@ object Connect {
       .intConf
       .createWithDefaultString("1000")
 
+  val CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL =
+    buildStaticConf("spark.connect.session.manager.maintenanceInterval")
+      .internal()
+      .doc("Interval at which session manager will search for expired sessions 
to remove.")
+      .version("4.0.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("30s")
+
   val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT =
     buildStaticConf("spark.connect.execute.manager.detachedTimeout")
       .internal()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 9e97ded5bf8a..f03f81326064 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -95,17 +95,17 @@ private[connect] class ExecuteHolder(
   private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)
 
   /** System.currentTimeMillis when this ExecuteHolder was created. */
-  val creationTime = System.currentTimeMillis()
+  val creationTimeMs = System.currentTimeMillis()
 
   /**
    * None if there is currently an attached RPC (grpcResponseSenders not empty or during initial
    * ExecutePlan handler). Otherwise, the System.currentTimeMillis when the last RPC detached
    * (grpcResponseSenders became empty).
    */
-  @volatile var lastAttachedRpcTime: Option[Long] = None
+  @volatile var lastAttachedRpcTimeMs: Option[Long] = None
 
   /** System.currentTimeMillis when this ExecuteHolder was closed. */
-  private var closedTime: Option[Long] = None
+  private var closedTimeMs: Option[Long] = None
 
   /**
    * Attached ExecuteGrpcResponseSenders that send the GRPC responses.
@@ -163,13 +163,13 @@ private[connect] class ExecuteHolder(
 
   private def addGrpcResponseSender(
      sender: ExecuteGrpcResponseSender[proto.ExecutePlanResponse]) = synchronized {
-    if (closedTime.isEmpty) {
+    if (closedTimeMs.isEmpty) {
       // Interrupt all other senders - there can be only one active sender.
      // Interrupted senders will remove themselves with removeGrpcResponseSender when they exit.
       grpcResponseSenders.foreach(_.interrupt())
       // And add this one.
       grpcResponseSenders += sender
-      lastAttachedRpcTime = None
+      lastAttachedRpcTimeMs = None
     } else {
       // execution is closing... interrupt it already.
       sender.interrupt()
@@ -178,11 +178,11 @@ private[connect] class ExecuteHolder(
 
   def removeGrpcResponseSender(sender: ExecuteGrpcResponseSender[_]): Unit = synchronized {
     // if closed, we are shutting down and interrupting all senders already
-    if (closedTime.isEmpty) {
+    if (closedTimeMs.isEmpty) {
       grpcResponseSenders -=
         sender.asInstanceOf[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]
       if (grpcResponseSenders.isEmpty) {
-        lastAttachedRpcTime = Some(System.currentTimeMillis())
+        lastAttachedRpcTimeMs = Some(System.currentTimeMillis())
       }
     }
   }
@@ -203,9 +203,9 @@ private[connect] class ExecuteHolder(
    * don't get garbage collected. End this grace period when the initial ExecutePlan ends.
    */
   def afterInitialRPC(): Unit = synchronized {
-    if (closedTime.isEmpty) {
+    if (closedTimeMs.isEmpty) {
       if (grpcResponseSenders.isEmpty) {
-        lastAttachedRpcTime = Some(System.currentTimeMillis())
+        lastAttachedRpcTimeMs = Some(System.currentTimeMillis())
       }
     }
   }
@@ -235,7 +235,7 @@ private[connect] class ExecuteHolder(
    * execution from global tracking and from its session.
    */
   def close(): Unit = synchronized {
-    if (closedTime.isEmpty) {
+    if (closedTimeMs.isEmpty) {
       // interrupt execution, if still running.
       runner.interrupt()
      // wait for execution to finish, to make sure no more results get pushed to responseObserver
@@ -244,14 +244,14 @@ private[connect] class ExecuteHolder(
       grpcResponseSenders.foreach(_.interrupt())
       // if there were still any grpcResponseSenders, register detach time
       if (grpcResponseSenders.nonEmpty) {
-        lastAttachedRpcTime = Some(System.currentTimeMillis())
+        lastAttachedRpcTimeMs = Some(System.currentTimeMillis())
         grpcResponseSenders.clear()
       }
       // remove all cached responses from observer
       responseObserver.removeAll()
       // post closed to UI
       eventsManager.postClosed()
-      closedTime = Some(System.currentTimeMillis())
+      closedTimeMs = Some(System.currentTimeMillis())
     }
   }
 
@@ -275,9 +275,9 @@ private[connect] class ExecuteHolder(
       sparkSessionTags = sparkSessionTags,
       reattachable = reattachable,
       status = eventsManager.status,
-      creationTime = creationTime,
-      lastAttachedRpcTime = lastAttachedRpcTime,
-      closedTime = closedTime)
+      creationTimeMs = creationTimeMs,
+      lastAttachedRpcTimeMs = lastAttachedRpcTimeMs,
+      closedTimeMs = closedTimeMs)
   }
 
   /** Get key used by SparkConnectExecutionManager global tracker. */
@@ -327,6 +327,9 @@ case class ExecuteInfo(
     sparkSessionTags: Set[String],
     reattachable: Boolean,
     status: ExecuteStatus,
-    creationTime: Long,
-    lastAttachedRpcTime: Option[Long],
-    closedTime: Option[Long])
+    creationTimeMs: Long,
+    lastAttachedRpcTimeMs: Option[Long],
+    closedTimeMs: Option[Long]) {
+
+  def key: ExecuteKey = ExecuteKey(userId, sessionId, operationId)
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index fd7c10d5c400..0fdf55ff42a0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
 import java.nio.file.Path
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -47,9 +48,20 @@ case class SessionKey(userId: String, sessionId: String)
 case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
     extends Logging {
 
-  @volatile private var lastRpcAccessTime: Option[Long] = None
+  // Time when the session was started.
+  private val startTimeMs: Long = System.currentTimeMillis()
 
-  @volatile private var isClosing: Boolean = false
+  // Time when the session was last accessed (retrieved from SparkConnectSessionManager)
+  @volatile private var lastAccessTimeMs: Long = System.currentTimeMillis()
+
+  // Time when the session was closed.
+  // Set only by close(), and only once.
+  @volatile private var closedTimeMs: Option[Long] = None
+
+  // Custom timeout after a session expires due to inactivity.
+  // Used by SparkConnectSessionManager instead of default timeout if set.
+  // Setting it to -1 indicates forever.
+  @volatile private var customInactiveTimeoutMs: Option[Long] = None
 
   private val executions: ConcurrentMap[String, ExecuteHolder] =
     new ConcurrentHashMap[String, ExecuteHolder]()
@@ -92,8 +104,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    *
    * Called only by SparkConnectExecutionManager under executionsLock.
    */
+  @GuardedBy("SparkConnectService.executionManager.executionsLock")
   private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = {
-    if (isClosing) {
+    if (closedTimeMs.isDefined) {
       // Do not accept new executions if the session is closing.
       throw new SparkSQLException(
         errorClass = "INVALID_HANDLE.SESSION_CLOSED",
@@ -108,7 +121,12 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     }
   }
 
-  /** Remove ExecuteHolder to this session. Called only by SparkConnectExecutionManager. */
+  /**
+   * Remove ExecuteHolder from this session.
+   *
+   * Called only by SparkConnectExecutionManager under executionsLock.
+   */
+  @GuardedBy("SparkConnectService.executionManager.executionsLock")
   private[service] def removeExecuteHolder(operationId: String): Unit = {
     executions.remove(operationId)
   }
@@ -186,7 +204,13 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
   def classloader: ClassLoader = artifactManager.classloader
 
   private[connect] def updateAccessTime(): Unit = {
-    lastRpcAccessTime = Some(System.currentTimeMillis())
+    lastAccessTimeMs = System.currentTimeMillis()
+    logInfo(s"Session $key accessed, time $lastAccessTimeMs.")
+  }
+
+  private[connect] def setCustomInactiveTimeoutMs(newInactiveTimeoutMs: Option[Long]): Unit = {
+    customInactiveTimeoutMs = newInactiveTimeoutMs
+    logInfo(s"Session $key inactive timeout set to $customInactiveTimeoutMs ms.")
   }
 
   /**
@@ -195,24 +219,30 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    * Called only by SparkConnectSessionManager.
    */
   private[connect] def initializeSession(): Unit = {
-    updateAccessTime()
     eventManager.postStarted()
   }
 
   /**
    * Expire this session and trigger state cleanup mechanisms.
    *
-   * Called only by SparkConnectSessionManager.
+   * Called only by SparkConnectSessionManager.shutdownSessionHolder.
    */
   private[connect] def close(): Unit = {
+    // Called only by SparkConnectSessionManager.shutdownSessionHolder.
+    // It is not called under SparkConnectSessionManager.sessionsLock, but it's guaranteed to be
+    // called only once, since removing the session from SparkConnectSessionManager.sessionStore is
+    // synchronized and guaranteed to happen only once.
+    if (closedTimeMs.isDefined) {
+      throw new IllegalStateException(s"Session $key is already closed.")
+    }
     logInfo(s"Closing session with userId: $userId and sessionId: $sessionId")
+    closedTimeMs = Some(System.currentTimeMillis())
 
-    // After isClosing=true, SessionHolder.addExecuteHolder() will not allow new executions for
-    // this session. Because both SessionHolder.addExecuteHolder() and
-    // SparkConnectExecutionManager.removeAllExecutionsForSession() are executed under
-    // executionsLock, this guarantees that removeAllExecutionsForSession triggered below will
-    // remove all executions and no new executions will be added in the meanwhile.
-    isClosing = true
+    if (eventManager.status == SessionStatus.Pending) {
+      // Testing-only: Some sessions created by SessionHolder.forTesting are not fully initialized
+      // and can't be closed.
+      return
+    }
 
     // Note on the below notes about concurrency:
    // While closing the session can potentially race with operations started on the session, the
@@ -229,8 +259,12 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
    streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers.
    removeAllListeners() // removes all listener and stop python listener processes if necessary.
 
-    // Clean up all executions
-    // It is guaranteed at this point that no new addExecuteHolder are getting started.
+    // Clean up all executions.
+    // After closedTimeMs is defined, SessionHolder.addExecuteHolder() will not allow new executions
+    // to be added for this session anymore. Because both SessionHolder.addExecuteHolder() and
+    // SparkConnectExecutionManager.removeAllExecutionsForSession() are executed under
+    // executionsLock, this guarantees that removeAllExecutionsForSession triggered here will
+    // remove all executions and no new executions will be added in the meanwhile.
     SparkConnectService.executionManager.removeAllExecutionsForSession(this.key)
 
     eventManager.postClosed()
@@ -251,7 +285,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
 
   /** Get SessionInfo with information about this SessionHolder. */
   def getSessionHolderInfo: SessionHolderInfo =
-    SessionHolderInfo(userId, sessionId, eventManager.status, lastRpcAccessTime)
+    SessionHolderInfo(
+      userId = userId,
+      sessionId = sessionId,
+      status = eventManager.status,
+      startTimeMs = startTimeMs,
+      lastAccessTimeMs = lastAccessTimeMs,
+      customInactiveTimeoutMs = customInactiveTimeoutMs,
+      closedTimeMs = closedTimeMs)
 
   /**
    * Caches given DataFrame with the ID. The cache does not expire. The entry needs to be
@@ -350,4 +391,9 @@ case class SessionHolderInfo(
     userId: String,
     sessionId: String,
     status: SessionStatus,
-    lastRpcAccesTime: Option[Long])
+    customInactiveTimeoutMs: Option[Long],
+    startTimeMs: Long,
+    lastAccessTimeMs: Long,
+    closedTimeMs: Option[Long]) {
+  def key: SessionKey = SessionKey(userId, sessionId)
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 5a9d0136de34..d8d9cee3dad4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connect.service
 
 import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -42,6 +43,7 @@ case class ExecuteKey(userId: String, sessionId: String, operationId: String)
 private[connect] class SparkConnectExecutionManager() extends Logging {
 
   /** Hash table containing all current executions. Guarded by executionsLock. */
+  @GuardedBy("executionsLock")
   private val executions: mutable.HashMap[ExecuteKey, ExecuteHolder] =
     new mutable.HashMap[ExecuteKey, ExecuteHolder]()
   private val executionsLock = new Object
@@ -53,7 +55,8 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     .build[ExecuteKey, ExecuteInfo]()
 
   /** None if there are no executions. Otherwise, the time when the last execution was removed. */
-  private var lastExecutionTime: Option[Long] = Some(System.currentTimeMillis())
+  @GuardedBy("executionsLock")
+  private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis())
 
   /** Executor for the periodic maintenance */
   private var scheduledExecutor: Option[ScheduledExecutorService] = None
@@ -82,7 +85,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
       }
       sessionHolder.addExecuteHolder(executeHolder)
       executions.put(executeHolder.key, executeHolder)
-      lastExecutionTime = None
+      lastExecutionTimeMs = None
       logInfo(s"ExecuteHolder ${executeHolder.key} is created.")
     }
 
@@ -100,18 +103,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     executionsLock.synchronized {
       executeHolder = executions.remove(key)
       executeHolder.foreach { e =>
+        // Put into abandonedTombstones under lock, so that if it's accessed it will end up
+        // with INVALID_HANDLE.OPERATION_ABANDONED error.
         if (abandoned) {
           abandonedTombstones.put(key, e.getExecuteInfo)
         }
         e.sessionHolder.removeExecuteHolder(e.operationId)
       }
       if (executions.isEmpty) {
-        lastExecutionTime = Some(System.currentTimeMillis())
+        lastExecutionTimeMs = Some(System.currentTimeMillis())
       }
       logInfo(s"ExecuteHolder $key is removed.")
     }
     // close the execution outside the lock
-    executeHolder.foreach(_.close())
+    executeHolder.foreach { e =>
+      e.close()
+      // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc.
+      abandonedTombstones.put(key, e.getExecuteInfo)
+    }
   }
 
   private[connect] def getExecuteHolder(key: ExecuteKey): Option[ExecuteHolder] = {
@@ -142,7 +151,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
    */
   def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionsLock.synchronized {
     if (executions.isEmpty) {
-      Left(lastExecutionTime.get)
+      Left(lastExecutionTimeMs.get)
     } else {
       Right(executions.values.map(_.getExecuteInfo).toBuffer.toSeq)
     }
@@ -162,10 +171,11 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
       executor.awaitTermination(1, TimeUnit.MINUTES)
     }
     scheduledExecutor = None
+    // note: this does not cleanly shut down the executions, but the server is shutting down.
     executions.clear()
     abandonedTombstones.invalidateAll()
-    if (lastExecutionTime.isEmpty) {
-      lastExecutionTime = Some(System.currentTimeMillis())
+    if (lastExecutionTimeMs.isEmpty) {
+      lastExecutionTimeMs = Some(System.currentTimeMillis())
     }
   }
 
@@ -175,18 +185,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
    * removes them after a timeout.
    */
   private def schedulePeriodicChecks(): Unit = executionsLock.synchronized {
-    val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL).toLong
-    val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT).toLong
-
     scheduledExecutor match {
       case Some(_) => // Already running.
       case None =>
+        val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL).toLong
         logInfo(s"Starting thread for cleanup of abandoned executions every $interval ms")
         scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
         scheduledExecutor.get.scheduleAtFixedRate(
           () => {
-            try periodicMaintenance(timeout)
-            catch {
+            try {
+              val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT).toLong
+              periodicMaintenance(timeout)
+            } catch {
               case NonFatal(ex) => logWarning("Unexpected exception in 
periodic task", ex)
             }
           },
@@ -206,7 +216,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
       val nowMs = System.currentTimeMillis()
 
       executions.values.foreach { executeHolder =>
-        executeHolder.lastAttachedRpcTime match {
+        executeHolder.lastAttachedRpcTimeMs match {
           case Some(detached) =>
             if (detached + timeout <= nowMs) {
               toRemove += executeHolder
@@ -215,13 +225,11 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
         }
       }
     }
-    if (!toRemove.isEmpty) {
-      // .. and remove them.
-      toRemove.foreach { executeHolder =>
-        val info = executeHolder.getExecuteInfo
-        logInfo(s"Found execution $info that was abandoned and expired and 
will be removed.")
-        removeExecuteHolder(executeHolder.key, abandoned = true)
-      }
+    // .. and remove them.
+    toRemove.foreach { executeHolder =>
+      val info = executeHolder.getExecuteInfo
+      logInfo(s"Found execution $info that was abandoned and expired and will 
be removed.")
+      removeExecuteHolder(executeHolder.key, abandoned = true)
     }
     logInfo("Finished periodic run of SparkConnectExecutionManager 
maintenance.")
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index ba402a90a71e..ef14cd305d40 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -18,15 +18,19 @@
 package org.apache.spark.sql.connect.service
 
 import java.util.UUID
-import java.util.concurrent.{Callable, TimeUnit}
+import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
 
-import com.google.common.base.Ticker
-import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification}
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
+
+import com.google.common.cache.CacheBuilder
 
 import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE, CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT}
+import org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE, CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT, CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL}
 
 /**
  * Global tracker of all SessionHolders holding Spark Connect sessions.
@@ -35,15 +39,8 @@ class SparkConnectSessionManager extends Logging {
 
   private val sessionsLock = new Object
 
-  private val sessionStore =
-    CacheBuilder
-      .newBuilder()
-      .ticker(Ticker.systemTicker())
-      .expireAfterAccess(
-        SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT),
-        TimeUnit.MILLISECONDS)
-      .removalListener(new RemoveSessionListener)
-      .build[SessionKey, SessionHolder]()
+  @GuardedBy("sessionsLock")
+  private val sessionStore = mutable.HashMap[SessionKey, SessionHolder]()
 
   private val closedSessionsCache =
     CacheBuilder
@@ -51,21 +48,23 @@ class SparkConnectSessionManager extends Logging {
       .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
       .build[SessionKey, SessionHolderInfo]()
 
+  /** Executor for the periodic maintenance */
+  private var scheduledExecutor: Option[ScheduledExecutorService] = None
+
   /**
    * Based on the userId and sessionId, find or create a new SparkSession.
    */
   private[connect] def getOrCreateIsolatedSession(key: SessionKey): SessionHolder = {
-    // Lock to guard against concurrent removal and insertion into closedSessionsCache.
-    sessionsLock.synchronized {
-      getSession(
-        key,
-        Some(() => {
-          validateSessionCreate(key)
-          val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
-          holder.initializeSession()
-          holder
-        }))
-    }
+    getSession(
+      key,
+      Some(() => {
+        // Executed under sessionsLock in getSession, to guard against concurrent removal
+        // and insertion into closedSessionsCache.
+        validateSessionCreate(key)
+        val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
+        holder.initializeSession()
+        holder
+      }))
   }
 
   /**
@@ -95,47 +94,161 @@ class SparkConnectSessionManager extends Logging {
     Option(getSession(key, None))
   }
 
-  private def getSession(
-      key: SessionKey,
-      default: Option[Callable[SessionHolder]]): SessionHolder = {
-    val session = default match {
-      case Some(callable) => sessionStore.get(key, callable)
-      case None => sessionStore.getIfPresent(key)
+  private def getSession(key: SessionKey, default: Option[() => SessionHolder]): SessionHolder = {
+    schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started yet.
+
+    sessionsLock.synchronized {
+      // try to get existing session from store
+      val sessionOpt = sessionStore.get(key)
+      // create using default if missing
+      val session = sessionOpt match {
+        case Some(s) => s
+        case None =>
+          default match {
+            case Some(callable) =>
+              val session = callable()
+              sessionStore.put(key, session)
+              session
+            case None =>
+              null
+          }
+      }
+      // record access time before returning
+      session match {
+        case null =>
+          null
+        case s: SessionHolder =>
+          s.updateAccessTime()
+          s
+      }
     }
-    // record access time before returning
-    session match {
-      case null =>
-        null
-      case s: SessionHolder =>
-        s.updateAccessTime()
-        s
+  }
+
+  // Removes session from sessionStore and returns it.
+  private def removeSessionHolder(key: SessionKey): Option[SessionHolder] = {
+    var sessionHolder: Option[SessionHolder] = None
+    sessionsLock.synchronized {
+      sessionHolder = sessionStore.remove(key)
+      sessionHolder.foreach { s =>
+        // Put into closedSessionsCache, so that it cannot get accidentally recreated
+        // by getOrCreateIsolatedSession.
+        closedSessionsCache.put(s.key, s.getSessionHolderInfo)
+      }
     }
+    sessionHolder
+  }
+
+  // Shuts down the session after removing it.
+  private def shutdownSessionHolder(sessionHolder: SessionHolder): Unit = {
+    sessionHolder.close()
+    // Update in closedSessionsCache: above it wasn't updated with closedTime etc. yet.
+    closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo)
   }
 
   def closeSession(key: SessionKey): Unit = {
-    // Invalidate will trigger RemoveSessionListener
-    sessionStore.invalidate(key)
+    val sessionHolder = removeSessionHolder(key)
+    // Rest of the cleanup outside sessionsLock - the session cannot be accessed anymore by
+    // getOrCreateIsolatedSession.
+    sessionHolder.foreach(shutdownSessionHolder(_))
   }
 
-  private class RemoveSessionListener extends RemovalListener[SessionKey, SessionHolder] {
-    override def onRemoval(notification: RemovalNotification[SessionKey, SessionHolder]): Unit = {
-      val sessionHolder = notification.getValue
-      sessionsLock.synchronized {
-        // First put into closedSessionsCache, so that it cannot get accidentally recreated by
-        // getOrCreateIsolatedSession.
-        closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo)
-      }
-      // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by
-      // getOrCreateIsolatedSession.
-      sessionHolder.close()
+  private[connect] def shutdown(): Unit = sessionsLock.synchronized {
+    scheduledExecutor.foreach { executor =>
+      executor.shutdown()
+      executor.awaitTermination(1, TimeUnit.MINUTES)
     }
+    scheduledExecutor = None
+    // note: this does not cleanly shut down the sessions, but the server is shutting down.
+    sessionStore.clear()
+    closedSessionsCache.invalidateAll()
   }
 
-  def shutdown(): Unit = {
+  def listActiveSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
+    sessionStore.values.map(_.getSessionHolderInfo).toSeq
+  }
+
+  def listClosedSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
+    closedSessionsCache.asMap.asScala.values.toSeq
+  }
+
+  /**
+   * Schedules periodic maintenance checks if it is not already scheduled.
+   *
+   * The checks are looking to remove sessions that expired.
+   */
+  private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized {
+    scheduledExecutor match {
+      case Some(_) => // Already running.
+      case None =>
+        val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL)
+        logInfo(s"Starting thread for cleanup of expired sessions every $interval ms")
+        scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
+        scheduledExecutor.get.scheduleAtFixedRate(
+          () => {
+            try {
+              val defaultInactiveTimeoutMs =
+                SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT)
+              periodicMaintenance(defaultInactiveTimeoutMs)
+            } catch {
+              case NonFatal(ex) => logWarning("Unexpected exception in 
periodic task", ex)
+            }
+          },
+          interval,
+          interval,
+          TimeUnit.MILLISECONDS)
+    }
+  }
+
+  // Visible for testing
+  private[connect] def periodicMaintenance(defaultInactiveTimeoutMs: Long): Unit =
+    periodicMaintenance(defaultInactiveTimeoutMs, ignoreCustomTimeout = false)
+
+  // Test only: ignoreCustomTimeout=true is used by invalidateAllSessions to force cleanup in tests.
+  private def periodicMaintenance(
+      defaultInactiveTimeoutMs: Long,
+      ignoreCustomTimeout: Boolean): Unit = {
+    logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
+    // Find any sessions that expired and should be removed.
+    val toRemove = new mutable.ArrayBuffer[SessionHolder]()
+
+    def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = {
+      val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && 
!ignoreCustomTimeout) {
+        info.customInactiveTimeoutMs.get
+      } else {
+        defaultInactiveTimeoutMs
+      }
+      // timeout of -1 indicates to never timeout
+      timeoutMs != -1 && info.lastAccessTimeMs + timeoutMs <= nowMs
+    }
+
     sessionsLock.synchronized {
-      sessionStore.invalidateAll()
-      closedSessionsCache.invalidateAll()
+      val nowMs = System.currentTimeMillis()
+      sessionStore.values.foreach { sessionHolder =>
+        if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
+          toRemove += sessionHolder
+        }
+      }
+    }
+    // .. and remove them.
+    toRemove.foreach { sessionHolder =>
+      // This doesn't use closeSession to be able to do the extra last chance check under lock.
+      val removedSession = sessionsLock.synchronized {
+        // Last chance - check expiration time and remove under lock if expired.
+        val info = sessionHolder.getSessionHolderInfo
+        if (shouldExpire(info, System.currentTimeMillis())) {
+          logInfo(s"Found session $info that expired and will be closed.")
+          removeSessionHolder(info.key)
+        } else {
+          None
+        }
+      }
+      // do shutdown and cleanup outside of lock.
+      try removedSession.foreach(shutdownSessionHolder(_))
+      catch {
+        case NonFatal(ex) => logWarning("Unexpected exception closing 
session", ex)
+      }
     }
+    logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
   }
 
   private def newIsolatedSession(): SparkSession = {
@@ -169,8 +282,9 @@ class SparkConnectSessionManager extends Logging {
   /**
    * Used for testing
    */
-  private[connect] def invalidateAllSessions(): Unit = {
-    sessionStore.invalidateAll()
+  private[connect] def invalidateAllSessions(): Unit = sessionsLock.synchronized {
+    periodicMaintenance(defaultInactiveTimeoutMs = 0L, ignoreCustomTimeout = true)
+    assert(sessionStore.isEmpty)
     closedSessionsCache.invalidateAll()
   }
 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index dbb06437c4d4..b04c42a73078 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -156,7 +156,7 @@ trait SparkConnectServerTest extends SharedSparkSession {
       case Right(executions) =>
         // all rpc detached.
         assert(
-          executions.forall(_.lastAttachedRpcTime.isDefined),
+          executions.forall(_.lastAttachedRpcTimeMs.isDefined),
           s"Expected no RPCs, but got $executions")
     }
   }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
new file mode 100644
index 000000000000..fadbd9fa502e
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndAfterEach {
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    SparkConnectService.sessionManager.invalidateAllSessions()
+  }
+
+  test("sessionId needs to be an UUID") {
+    val key = SessionKey("user", "not an uuid")
+    val exGetOrCreate = intercept[SparkSQLException] {
+      SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+    }
+    assert(exGetOrCreate.getErrorClass == "INVALID_HANDLE.FORMAT")
+  }
+
+  test(
+    "getOrCreateIsolatedSession/getIsolatedSession/getIsolatedSessionIfPresent 
" +
+      "gets the existing session") {
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+
+    val sessionGetOrCreate =
+      SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+    assert(sessionGetOrCreate === sessionHolder)
+
+    val sessionGet = SparkConnectService.sessionManager.getIsolatedSession(key)
+    assert(sessionGet === sessionHolder)
+
+    val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
+    assert(sessionGetIfPresent.get === sessionHolder)
+  }
+
+  test(
+    "getOrCreateIsolatedSession/getIsolatedSession/getIsolatedSessionIfPresent 
" +
+      "doesn't recreate closed session") {
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+    SparkConnectService.sessionManager.closeSession(key)
+
+    val exGetOrCreate = intercept[SparkSQLException] {
+      SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+    }
+    assert(exGetOrCreate.getErrorClass == "INVALID_HANDLE.SESSION_CLOSED")
+
+    val exGet = intercept[SparkSQLException] {
+      SparkConnectService.sessionManager.getIsolatedSession(key)
+    }
+    assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_CLOSED")
+
+    val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
+    assert(sessionGetIfPresent.isEmpty)
+  }
+
+  test("getIsolatedSession/getIsolatedSessionIfPresent when session doesn't 
exist") {
+    val key = SessionKey("user", UUID.randomUUID().toString)
+
+    val exGet = intercept[SparkSQLException] {
+      SparkConnectService.sessionManager.getIsolatedSession(key)
+    }
+    assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_NOT_FOUND")
+
+    val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key)
+    assert(sessionGetIfPresent.isEmpty)
+  }
+
+  test("SessionHolder with custom expiration time is not cleaned up due to 
inactivity") {
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+
+    assert(
+      SparkConnectService.sessionManager.listActiveSessions.exists(
+        _.sessionId == sessionHolder.sessionId))
+    sessionHolder.setCustomInactiveTimeoutMs(Some(5.days.toMillis))
+
+    // clean up with inactivity timeout of 0.
+    SparkConnectService.sessionManager.periodicMaintenance(defaultInactiveTimeoutMs = 0L)
+    // session should still be there.
+    assert(
+      SparkConnectService.sessionManager.listActiveSessions.exists(
+        _.sessionId == sessionHolder.sessionId))
+
+    sessionHolder.setCustomInactiveTimeoutMs(None)
+    // it will be cleaned up now.
+    SparkConnectService.sessionManager.periodicMaintenance(defaultInactiveTimeoutMs = 0L)
+    assert(SparkConnectService.sessionManager.listActiveSessions.isEmpty)
+    assert(
+      SparkConnectService.sessionManager.listClosedSessions.exists(
+        _.sessionId == sessionHolder.sessionId))
+  }
+
+  test("SessionHolder is recorded with status closed after close") {
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key)
+
+    val activeSessionInfo = SparkConnectService.sessionManager.listActiveSessions.find(
+      _.sessionId == sessionHolder.sessionId)
+    assert(activeSessionInfo.isDefined)
+    assert(activeSessionInfo.get.status == SessionStatus.Started)
+    assert(activeSessionInfo.get.closedTimeMs.isEmpty)
+
+    SparkConnectService.sessionManager.closeSession(sessionHolder.key)
+
+    assert(SparkConnectService.sessionManager.listActiveSessions.isEmpty)
+    val closedSessionInfo = SparkConnectService.sessionManager.listClosedSessions.find(
+      _.sessionId == sessionHolder.sessionId)
+    assert(closedSessionInfo.isDefined)
+    assert(closedSessionInfo.get.status == SessionStatus.Closed)
+    assert(closedSessionInfo.get.closedTimeMs.isDefined)
+  }
+}

