This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a3bd477a6d8 [SPARK-44425][CONNECT] Validate that user provided sessionId is an UUID a3bd477a6d8 is described below commit a3bd477a6d8c317ee1e9a6aae6ebd2ef4fc67cce Author: Juliusz Sompolski <ju...@databricks.com> AuthorDate: Fri Jul 28 07:55:02 2023 +0900 [SPARK-44425][CONNECT] Validate that user provided sessionId is an UUID ### What changes were proposed in this pull request? We want to validate that user provided sessionId is an UUID. Existing Spark Connect python and scala clients already do that, we would like to depend on it being in this format moving forward, just like we already validate that operationId is an UUID. ### Why are the changes needed? Validate what's already assumed. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing CI. Closes #42150 from juliuszsompolski/SPARK-44425. Authored-by: Juliusz Sompolski <ju...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../src/main/protobuf/spark/connect/base.proto | 6 ++++++ .../sql/connect/service/SparkConnectService.scala | 12 ++++++++++- .../connect/artifact/ArtifactManagerSuite.scala | 15 +++++++++----- .../connect/planner/SparkConnectServiceSuite.scala | 23 +++++++++++++--------- .../connect/service/AddArtifactsHandlerSuite.scala | 13 +++++++----- python/pyspark/sql/connect/proto/base_pb2.pyi | 6 ++++++ 6 files changed, 55 insertions(+), 20 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index d935ae65328..21fd167f6b5 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -63,6 +63,7 @@ message AnalyzePlanRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). 
The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context @@ -273,6 +274,7 @@ message ExecutePlanRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context @@ -407,6 +409,7 @@ message ConfigRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context @@ -492,6 +495,7 @@ message AddArtifactsRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // User context @@ -581,6 +585,7 @@ message ArtifactStatusesRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. 
+ // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // User context @@ -617,6 +622,7 @@ message InterruptRequest { // The session_id specifies a spark session for a user id (which is specified // by user_context.user_id). The session_id is set by the client to be able to // collate streaming responses from different queries within the dedicated session. + // The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` string session_id = 1; // (Required) User context diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 6b7007130be..87e4f21732f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.net.InetSocketAddress +import java.util.UUID import java.util.concurrent.TimeUnit import com.google.common.base.Ticker @@ -28,7 +29,7 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkSQLException} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse} import org.apache.spark.internal.Logging @@ -221,6 +222,15 @@ object SparkConnectService extends Logging { * Based on the `key` find or create a new SparkSession. */ def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = { + // Validate that sessionId is formatted like UUID before creating session. 
+ try { + UUID.fromString(sessionId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> sessionId)) + } userSessionMapping.get( (userId, sessionId), () => { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index 199290327cf..fa3b7d52379 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.artifact import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} +import java.util.UUID import org.apache.commons.io.FileUtils @@ -96,7 +97,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString()) sessionHolder.addArtifact(remotePath, stagingPath, None) val movedClassFile = SparkConnectArtifactManager.artifactRootPath @@ -208,9 +210,11 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { } test("Classloaders for spark sessions are isolated") { - val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1") - val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2") - val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", "session3") + // use same sessionId - different users should still make it isolated. 
+ val sessionId = UUID.randomUUID.toString() + val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId) + val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", sessionId) + val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", sessionId) def addHelloClass(holder: SessionHolder): Unit = { val copyDir = Utils.createTempDir().toPath @@ -267,7 +271,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") - val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString) sessionHolder.addArtifact(remotePath, stagingPath, None) val sessionDirectory = diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index d820e65a685..c29a9b9b629 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -167,6 +167,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) .build() // Execute plan. 
@@ -334,7 +335,8 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .setCommand(ByteString.copyFrom("command".getBytes())) .setPythonVer("3.10") .build())))) { command => - withCommandTest { verifyEvents => + val sessionId = UUID.randomUUID.toString() + withCommandTest(sessionId) { verifyEvents => val instance = new SparkConnectService(false) val context = proto.UserContext .newBuilder() @@ -348,7 +350,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with val request = proto.ExecutePlanRequest .newBuilder() .setPlan(plan) - .setSessionId("s1") + .setSessionId(sessionId) .setUserContext(context) .build() @@ -393,11 +395,12 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } test("SPARK-43923: canceled request send events") { + val sessionId = UUID.randomUUID.toString withEvents { verifyEvents => val instance = new SparkConnectService(false) // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session + val session = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId).session val sleep: Long => Long = { time => Thread.sleep(time) time @@ -417,7 +420,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) - .setSessionId("session") + .setSessionId(sessionId) .build() val thread = new Thread { @@ -426,7 +429,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with instance.interrupt( proto.InterruptRequest .newBuilder() - .setSessionId("session") + .setSessionId(sessionId) .setUserContext(context) .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL) .build(), @@ -463,11 +466,12 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } test("SPARK-41165: failures in the arrow collect path should not cause hangs") { + val sessionId = 
UUID.randomUUID.toString withEvents { verifyEvents => val instance = new SparkConnectService(false) // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session + val session = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId).session val instaKill: Long => Long = { _ => throw new Exception("Kaboom") } @@ -486,7 +490,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) - .setSessionId("session") + .setSessionId(sessionId) .build() // The observer is executed inside this thread. So @@ -599,6 +603,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with .newBuilder() .setPlan(plan) .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) .build() // Execute plan. @@ -637,7 +642,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } - protected def withCommandTest(f: VerifyEvents => Unit): Unit = { + protected def withCommandTest(sessionId: String)(f: VerifyEvents => Unit): Unit = { withView("testview") { withTable("testcat.testtable") { withSparkConf( @@ -649,7 +654,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with when(restartedQuery.id).thenReturn(DEFAULT_UUID) when(restartedQuery.runId).thenReturn(DEFAULT_UUID) SparkConnectService.streamingSessionManager.registerNewStreamingQuery( - SparkConnectService.getOrCreateIsolatedSession("c1", "s1"), + SparkConnectService.getOrCreateIsolatedSession("c1", sessionId), restartedQuery) f(verifyEvents) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala index f11c9b2969e..2e199bff5e7 100644 --- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.io.InputStream import java.nio.file.{Files, Path} +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -37,6 +38,8 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { private val CHUNK_SIZE: Int = 32 * 1024 + private val sessionId = UUID.randomUUID.toString() + class DummyStreamObserver(p: Promise[AddArtifactsResponse]) extends StreamObserver[AddArtifactsResponse] { override def onNext(v: AddArtifactsResponse): Unit = p.success(v) @@ -125,7 +128,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -168,7 +171,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val requestBuilder = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBeginChunk(beginChunkedArtifact) @@ -295,7 +298,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -336,7 +339,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) 
.setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -353,7 +356,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val beginChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBeginChunk(beginChunkedArtifact) .build() diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 651438ea438..6059d38bd19 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -499,6 +499,7 @@ class AnalyzePlanRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -1042,6 +1043,7 @@ class ExecutePlanRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -1722,6 +1724,7 @@ class ConfigRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. 
+ The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -1961,6 +1964,7 @@ class AddArtifactsRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -2112,6 +2116,7 @@ class ArtifactStatusesRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: @@ -2278,6 +2283,7 @@ class InterruptRequest(google.protobuf.message.Message): The session_id specifies a spark session for a user id (which is specified by user_context.user_id). The session_id is set by the client to be able to collate streaming responses from different queries within the dedicated session. + The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` """ @property def user_context(self) -> global___UserContext: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org