This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new cbcd85c5e57 [SPARK-44425][CONNECT] Validate that user provided sessionId is an UUID cbcd85c5e57 is described below commit cbcd85c5e57695f3992eaf694d61be86f84449c3 Author: Juliusz Sompolski <ju...@databricks.com> AuthorDate: Fri Jul 28 07:55:02 2023 +0900 [SPARK-44425][CONNECT] Validate that user provided sessionId is an UUID We want to validate that user provided sessionId is an UUID. Existing Spark Connect python and scala clients already do that, we would like to depend on it being in this format moving forward, just like we already validate that operationId is an UUID. Validate what's already assumed. No. Existing CI. Closes #42150 from juliuszsompolski/SPARK-44425. Authored-by: Juliusz Sompolski <ju...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit a3bd477a6d8c317ee1e9a6aae6ebd2ef4fc67cce) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../src/main/protobuf/spark/connect/base.proto | 6 ++++++ .../sql/connect/service/SparkConnectService.scala | 12 ++++++++++- .../connect/artifact/ArtifactManagerSuite.scala | 15 +++++++++----- .../connect/planner/SparkConnectServiceSuite.scala | 23 +++++++++++++--------- .../connect/service/AddArtifactsHandlerSuite.scala | 13 +++++++----- python/pyspark/sql/connect/proto/base_pb2.pyi | 6 ++++++ 6 files changed, 55 insertions(+), 20 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index d935ae65328..21fd167f6b5 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -63,6 +63,7 @@ message AnalyzePlanRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). 
The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context @@ -273,6 +274,7 @@ message ExecutePlanRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context @@ -407,6 +409,7 @@ message ConfigRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context @@ -492,6 +495,7 @@ message AddArtifactsRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // User context @@ -581,6 +585,7 @@ message ArtifactStatusesRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. 
+ // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // User context @@ -617,6 +622,7 @@ message InterruptRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index ad40c94d549..c8fbfca6f70 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connect.service +import java.util.UUID import java.util.concurrent.TimeUnit import com.google.common.base.Ticker @@ -27,7 +28,7 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkSQLException} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse} import org.apache.spark.internal.Logging @@ -220,6 +221,15 @@ object SparkConnectService { * Based on the `key` find or create a new SparkSession. */ def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = { + // Validate that sessionId is formatted like UUID before creating session. 
+ try { + UUID.fromString(sessionId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> sessionId)) + } userSessionMapping.get( (userId, sessionId), () => { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index 199290327cf..fa3b7d52379 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.artifact import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} +import java.util.UUID import org.apache.commons.io.FileUtils @@ -96,7 +97,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString()) sessionHolder.addArtifact(remotePath, stagingPath, None) val movedClassFile = SparkConnectArtifactManager.artifactRootPath @@ -208,9 +210,11 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { } test("Classloaders for spark sessions are isolated") { - val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1") - val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2") - val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", "session3") + // use same sessionId - different users should still make it isolated. 
+ val sessionId = UUID.randomUUID.toString() + val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId) + val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", sessionId) + val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", sessionId) def addHelloClass(holder: SessionHolder): Unit = { val copyDir = Utils.createTempDir().toPath @@ -267,7 +271,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") - val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString) sessionHolder.addArtifact(remotePath, stagingPath, None) val sessionDirectory = diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index d820e65a685..c29a9b9b629 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -167,6 +167,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) .build() // Execute plan. 
@@ -334,7 +335,8 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .setCommand(ByteString.copyFrom("command".getBytes())) .setPythonVer("3.10") .build())))) { command => - withCommandTest { verifyEvents => + val sessionId = UUID.randomUUID.toString() + withCommandTest(sessionId) { verifyEvents => val instance = new SparkConnectService(false) val context = proto.UserContext .newBuilder() @@ -348,7 +350,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with val request = proto.ExecutePlanRequest .newBuilder() .setPlan(plan) - .setSessionId("s1") + .setSessionId(sessionId) .setUserContext(context) .build() @@ -393,11 +395,12 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } test("SPARK-43923: canceled request send events") { + val sessionId = UUID.randomUUID.toString withEvents { verifyEvents => val instance = new SparkConnectService(false) // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session + val session = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId).session val sleep: Long => Long = { time => Thread.sleep(time) time @@ -417,7 +420,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) - .setSessionId("session") + .setSessionId(sessionId) .build() val thread = new Thread { @@ -426,7 +429,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with instance.interrupt( proto.InterruptRequest .newBuilder() - .setSessionId("session") + .setSessionId(sessionId) .setUserContext(context) .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL) .build(), @@ -463,11 +466,12 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } test("SPARK-41165: failures in the arrow collect path should not cause hangs") { + val sessionId = 
UUID.randomUUID.toString withEvents { verifyEvents => val instance = new SparkConnectService(false) // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session + val session = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId).session val instaKill: Long => Long = { _ => throw new Exception("Kaboom") } @@ -486,7 +490,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) - .setSessionId("session") + .setSessionId(sessionId) .build() // The observer is executed inside this thread. So @@ -599,6 +603,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) .build() // Execute plan. @@ -637,7 +642,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } - protected def withCommandTest(f: VerifyEvents => Unit): Unit = { + protected def withCommandTest(sessionId: String)(f: VerifyEvents => Unit): Unit = { withView("testview") { withTable("testcat.testtable") { withSparkConf( @@ -649,7 +654,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with when(restartedQuery.id).thenReturn(DEFAULT_UUID) when(restartedQuery.runId).thenReturn(DEFAULT_UUID) SparkConnectService.streamingSessionManager.registerNewStreamingQuery( - SparkConnectService.getOrCreateIsolatedSession("c1", "s1"), + SparkConnectService.getOrCreateIsolatedSession("c1", sessionId), restartedQuery) f(verifyEvents) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala index f11c9b2969e..2e199bff5e7 100644 --- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.io.InputStream import java.nio.file.{Files, Path} +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -37,6 +38,8 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { private val CHUNK_SIZE: Int = 32 * 1024 + private val sessionId = UUID.randomUUID.toString() + class DummyStreamObserver(p: Promise[AddArtifactsResponse]) extends StreamObserver[AddArtifactsResponse] { override def onNext(v: AddArtifactsResponse): Unit = p.success(v) @@ -125,7 +128,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -168,7 +171,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val requestBuilder = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBeginChunk(beginChunkedArtifact) @@ -295,7 +298,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -336,7 +339,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) 
.setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -353,7 +356,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val beginChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBeginChunk(beginChunkedArtifact) .build() diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 651438ea438..6059d38bd19 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -499,6 +499,7 @@ class AnalyzePlanRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -1042,6 +1043,7 @@ class ExecutePlanRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -1722,6 +1724,7 @@ class ConfigRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. 
+ The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -1961,6 +1964,7 @@ class AddArtifactsRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -2112,6 +2116,7 @@ class ArtifactStatusesRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -2278,6 +2283,7 @@ class InterruptRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org