This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new cbcd85c5e57 [SPARK-44425][CONNECT] Validate that user provided 
sessionId is an UUID
cbcd85c5e57 is described below

commit cbcd85c5e57695f3992eaf694d61be86f84449c3
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Fri Jul 28 07:55:02 2023 +0900

    [SPARK-44425][CONNECT] Validate that user provided sessionId is an UUID
    
    We want to validate that the user-provided sessionId is a UUID. The existing
    Spark Connect Python and Scala clients already use this format, and we would
    like to depend on it being in this format moving forward, just as we already
    validate that operationId is a UUID.
    
    Why are the changes needed: validate what's already assumed.
    
    User-facing change: No.
    
    How was this patch tested: Existing CI.
    
    Closes #42150 from juliuszsompolski/SPARK-44425.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit a3bd477a6d8c317ee1e9a6aae6ebd2ef4fc67cce)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/protobuf/spark/connect/base.proto     |  6 ++++++
 .../sql/connect/service/SparkConnectService.scala  | 12 ++++++++++-
 .../connect/artifact/ArtifactManagerSuite.scala    | 15 +++++++++-----
 .../connect/planner/SparkConnectServiceSuite.scala | 23 +++++++++++++---------
 .../connect/service/AddArtifactsHandlerSuite.scala | 13 +++++++-----
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  6 ++++++
 6 files changed, 55 insertions(+), 20 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index d935ae65328..21fd167f6b5 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -63,6 +63,7 @@ message AnalyzePlanRequest {
   // The session_id specifies a spark session for a user id (which is specified
   // by user_context.user_id). The session_id is set by the client to be able 
to
   // collate streaming responses from different queries within the dedicated 
session.
+  // The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
   string session_id = 1;
 
   // (Required) User context
@@ -273,6 +274,7 @@ message ExecutePlanRequest {
   // The session_id specifies a spark session for a user id (which is specified
   // by user_context.user_id). The session_id is set by the client to be able 
to
   // collate streaming responses from different queries within the dedicated 
session.
+  // The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
   string session_id = 1;
 
   // (Required) User context
@@ -407,6 +409,7 @@ message ConfigRequest {
   // The session_id specifies a spark session for a user id (which is specified
   // by user_context.user_id). The session_id is set by the client to be able 
to
   // collate streaming responses from different queries within the dedicated 
session.
+  // The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
   string session_id = 1;
 
   // (Required) User context
@@ -492,6 +495,7 @@ message AddArtifactsRequest {
   // The session_id specifies a spark session for a user id (which is specified
   // by user_context.user_id). The session_id is set by the client to be able 
to
   // collate streaming responses from different queries within the dedicated 
session.
+  // The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
   string session_id = 1;
 
   // User context
@@ -581,6 +585,7 @@ message ArtifactStatusesRequest {
   // The session_id specifies a spark session for a user id (which is specified
   // by user_context.user_id). The session_id is set by the client to be able 
to
   // collate streaming responses from different queries within the dedicated 
session.
+  // The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
   string session_id = 1;
 
   // User context
@@ -617,6 +622,7 @@ message InterruptRequest {
   // The session_id specifies a spark session for a user id (which is specified
   // by user_context.user_id). The session_id is set by the client to be able 
to
   // collate streaming responses from different queries within the dedicated 
session.
+  // The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
   string session_id = 1;
 
   // (Required) User context
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index ad40c94d549..c8fbfca6f70 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit
 
 import com.google.common.base.Ticker
@@ -27,7 +28,7 @@ import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, 
AddArtifactsResponse}
 import org.apache.spark.internal.Logging
@@ -220,6 +221,15 @@ object SparkConnectService {
    * Based on the `key` find or create a new SparkSession.
    */
   def getOrCreateIsolatedSession(userId: String, sessionId: String): 
SessionHolder = {
+    // Validate that sessionId is formatted like UUID before creating session.
+    try {
+      UUID.fromString(sessionId).toString
+    } catch {
+      case _: IllegalArgumentException =>
+        throw new SparkSQLException(
+          errorClass = "INVALID_HANDLE.FORMAT",
+          messageParameters = Map("handle" -> sessionId))
+    }
     userSessionMapping.get(
       (userId, sessionId),
       () => {
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 199290327cf..fa3b7d52379 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.artifact
 
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths}
+import java.util.UUID
 
 import org.apache.commons.io.FileUtils
 
@@ -96,7 +97,8 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
     val remotePath = Paths.get("classes/Hello.class")
     assert(stagingPath.toFile.exists())
 
-    val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session")
+    val sessionHolder =
+      SparkConnectService.getOrCreateIsolatedSession("c1", 
UUID.randomUUID.toString())
     sessionHolder.addArtifact(remotePath, stagingPath, None)
 
     val movedClassFile = SparkConnectArtifactManager.artifactRootPath
@@ -208,9 +210,11 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
   }
 
   test("Classloaders for spark sessions are isolated") {
-    val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session1")
-    val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", 
"session2")
-    val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", 
"session3")
+    // use same sessionId - different users should still make it isolated.
+    val sessionId = UUID.randomUUID.toString()
+    val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", 
sessionId)
+    val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", 
sessionId)
+    val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", 
sessionId)
 
     def addHelloClass(holder: SessionHolder): Unit = {
       val copyDir = Utils.createTempDir().toPath
@@ -267,7 +271,8 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
     val stagingPath = copyDir.resolve("Hello.class")
     val remotePath = Paths.get("classes/Hello.class")
 
-    val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session")
+    val sessionHolder =
+      SparkConnectService.getOrCreateIsolatedSession("c1", 
UUID.randomUUID.toString)
     sessionHolder.addArtifact(remotePath, stagingPath, None)
 
     val sessionDirectory =
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index d820e65a685..c29a9b9b629 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -167,6 +167,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
         .newBuilder()
         .setPlan(plan)
         .setUserContext(context)
+        .setSessionId(UUID.randomUUID.toString())
         .build()
 
       // Execute plan.
@@ -334,7 +335,8 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
                 .setCommand(ByteString.copyFrom("command".getBytes()))
                 .setPythonVer("3.10")
                 .build())))) { command =>
-    withCommandTest { verifyEvents =>
+    val sessionId = UUID.randomUUID.toString()
+    withCommandTest(sessionId) { verifyEvents =>
       val instance = new SparkConnectService(false)
       val context = proto.UserContext
         .newBuilder()
@@ -348,7 +350,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
       val request = proto.ExecutePlanRequest
         .newBuilder()
         .setPlan(plan)
-        .setSessionId("s1")
+        .setSessionId(sessionId)
         .setUserContext(context)
         .build()
 
@@ -393,11 +395,12 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
   }
 
   test("SPARK-43923: canceled request send events") {
+    val sessionId = UUID.randomUUID.toString
     withEvents { verifyEvents =>
       val instance = new SparkConnectService(false)
 
       // Add an always crashing UDF
-      val session = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session").session
+      val session = SparkConnectService.getOrCreateIsolatedSession("c1", 
sessionId).session
       val sleep: Long => Long = { time =>
         Thread.sleep(time)
         time
@@ -417,7 +420,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
         .newBuilder()
         .setPlan(plan)
         .setUserContext(context)
-        .setSessionId("session")
+        .setSessionId(sessionId)
         .build()
 
       val thread = new Thread {
@@ -426,7 +429,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
           instance.interrupt(
             proto.InterruptRequest
               .newBuilder()
-              .setSessionId("session")
+              .setSessionId(sessionId)
               .setUserContext(context)
               
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
               .build(),
@@ -463,11 +466,12 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
   }
 
   test("SPARK-41165: failures in the arrow collect path should not cause 
hangs") {
+    val sessionId = UUID.randomUUID.toString
     withEvents { verifyEvents =>
       val instance = new SparkConnectService(false)
 
       // Add an always crashing UDF
-      val session = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session").session
+      val session = SparkConnectService.getOrCreateIsolatedSession("c1", 
sessionId).session
       val instaKill: Long => Long = { _ =>
         throw new Exception("Kaboom")
       }
@@ -486,7 +490,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
         .newBuilder()
         .setPlan(plan)
         .setUserContext(context)
-        .setSessionId("session")
+        .setSessionId(sessionId)
         .build()
 
       // The observer is executed inside this thread. So
@@ -599,6 +603,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
         .newBuilder()
         .setPlan(plan)
         .setUserContext(context)
+        .setSessionId(UUID.randomUUID.toString())
         .build()
 
       // Execute plan.
@@ -637,7 +642,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
     }
   }
 
-  protected def withCommandTest(f: VerifyEvents => Unit): Unit = {
+  protected def withCommandTest(sessionId: String)(f: VerifyEvents => Unit): 
Unit = {
     withView("testview") {
       withTable("testcat.testtable") {
         withSparkConf(
@@ -649,7 +654,7 @@ class SparkConnectServiceSuite extends SharedSparkSession 
with MockitoSugar with
             when(restartedQuery.id).thenReturn(DEFAULT_UUID)
             when(restartedQuery.runId).thenReturn(DEFAULT_UUID)
             
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
-              SparkConnectService.getOrCreateIsolatedSession("c1", "s1"),
+              SparkConnectService.getOrCreateIsolatedSession("c1", sessionId),
               restartedQuery)
             f(verifyEvents)
           }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
index f11c9b2969e..2e199bff5e7 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service
 
 import java.io.InputStream
 import java.nio.file.{Files, Path}
+import java.util.UUID
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -37,6 +38,8 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
 
   private val CHUNK_SIZE: Int = 32 * 1024
 
+  private val sessionId = UUID.randomUUID.toString()
+
   class DummyStreamObserver(p: Promise[AddArtifactsResponse])
       extends StreamObserver[AddArtifactsResponse] {
     override def onNext(v: AddArtifactsResponse): Unit = p.success(v)
@@ -125,7 +128,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
 
     val singleChunkArtifactRequest = AddArtifactsRequest
       .newBuilder()
-      .setSessionId("abc")
+      .setSessionId(sessionId)
       .setUserContext(context)
       .setBatch(
         
proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build())
@@ -168,7 +171,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
 
     val requestBuilder = AddArtifactsRequest
       .newBuilder()
-      .setSessionId("abc")
+      .setSessionId(sessionId)
       .setUserContext(context)
       .setBeginChunk(beginChunkedArtifact)
 
@@ -295,7 +298,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
 
       val singleChunkArtifactRequest = AddArtifactsRequest
         .newBuilder()
-        .setSessionId("abc")
+        .setSessionId(sessionId)
         .setUserContext(context)
         .setBatch(
           
proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build())
@@ -336,7 +339,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
 
     val singleChunkArtifactRequest = AddArtifactsRequest
       .newBuilder()
-      .setSessionId("abc")
+      .setSessionId(sessionId)
       .setUserContext(context)
       .setBatch(
         
proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build())
@@ -353,7 +356,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
 
     val beginChunkArtifactRequest = AddArtifactsRequest
       .newBuilder()
-      .setSessionId("abc")
+      .setSessionId(sessionId)
       .setUserContext(context)
       .setBeginChunk(beginChunkedArtifact)
       .build()
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi 
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 651438ea438..6059d38bd19 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -499,6 +499,7 @@ class AnalyzePlanRequest(google.protobuf.message.Message):
     The session_id specifies a spark session for a user id (which is specified
     by user_context.user_id). The session_id is set by the client to be able to
     collate streaming responses from different queries within the dedicated 
session.
+    The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
     """
     @property
     def user_context(self) -> global___UserContext:
@@ -1042,6 +1043,7 @@ class ExecutePlanRequest(google.protobuf.message.Message):
     The session_id specifies a spark session for a user id (which is specified
     by user_context.user_id). The session_id is set by the client to be able to
     collate streaming responses from different queries within the dedicated 
session.
+    The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
     """
     @property
     def user_context(self) -> global___UserContext:
@@ -1722,6 +1724,7 @@ class ConfigRequest(google.protobuf.message.Message):
     The session_id specifies a spark session for a user id (which is specified
     by user_context.user_id). The session_id is set by the client to be able to
     collate streaming responses from different queries within the dedicated 
session.
+    The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
     """
     @property
     def user_context(self) -> global___UserContext:
@@ -1961,6 +1964,7 @@ class 
AddArtifactsRequest(google.protobuf.message.Message):
     The session_id specifies a spark session for a user id (which is specified
     by user_context.user_id). The session_id is set by the client to be able to
     collate streaming responses from different queries within the dedicated 
session.
+    The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
     """
     @property
     def user_context(self) -> global___UserContext:
@@ -2112,6 +2116,7 @@ class 
ArtifactStatusesRequest(google.protobuf.message.Message):
     The session_id specifies a spark session for a user id (which is specified
     by user_context.user_id). The session_id is set by the client to be able to
     collate streaming responses from different queries within the dedicated 
session.
+    The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
     """
     @property
     def user_context(self) -> global___UserContext:
@@ -2278,6 +2283,7 @@ class InterruptRequest(google.protobuf.message.Message):
     The session_id specifies a spark session for a user id (which is specified
     by user_context.user_id). The session_id is set by the client to be able to
     collate streaming responses from different queries within the dedicated 
session.
+    The id should be an UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`
     """
     @property
     def user_context(self) -> global___UserContext:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to