This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new bd8c9bda623 [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client bd8c9bda623 is described below commit bd8c9bda623ea9dddd67f8f09cb1d197cf8906b1 Author: Zhen Li <zhenli...@users.noreply.github.com> AuthorDate: Fri Feb 24 08:44:23 2023 -0400 [SPARK-42533][CONNECT][SCALA] Add ssl for Scala client ### What changes were proposed in this pull request? Adding SSL encryption and access token support for Scala client ### Why are the changes needed? To support basic client side encryption to protect data sent over the network. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. Manual tests. Closes #40133 from zhenlineo/ssl. Authored-by: Zhen Li <zhenli...@users.noreply.github.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit c0d301ea3c3f6e3d1b10373823e0aeeb997e8daf) Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../scala/org/apache/spark/sql/SparkSession.scala | 7 +- .../sql/connect/client/SparkConnectClient.scala | 193 +++++++++++++++++++-- .../apache/spark/sql/PlanGenerationTestSuite.scala | 2 +- .../connect/client/SparkConnectClientSuite.scala | 76 ++++---- 4 files changed, 229 insertions(+), 49 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index b086db09365..0e5aaace20d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -224,7 +224,12 @@ object SparkSession extends Logging { class Builder() extends Logging { private var _client: SparkConnectClient = _ - def client(client: SparkConnectClient): Builder = { + def remote(connectionString: String): Builder = { + client(SparkConnectClient.builder().connectionString(connectionString).build()) + this + } + + private[sql] def client(client: SparkConnectClient): Builder = { _client = client this } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 3049a0a0a5d..12bb581880c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.connect.client -import scala.language.existentials - -import io.grpc.{ManagedChannel, ManagedChannelBuilder} +import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel, Metadata, MethodDescriptor, Status, TlsChannelCredentials} import java.net.URI import java.util.UUID +import java.util.concurrent.Executor import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.UserContext import org.apache.spark.sql.connect.common.config.ConnectCommon /** * Conceptually the remote spark session that communicates with the server. */ -class SparkConnectClient( +private[sql] class SparkConnectClient( private val userContext: proto.UserContext, - private val channel: ManagedChannel) { + private val channel: ManagedChannel, + private[client] val userAgent: String) { private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) @@ -40,7 +41,7 @@ class SparkConnectClient( * @return * User ID. */ - def userId: String = userContext.getUserId() + private[client] def userId: String = userContext.getUserId() // Generate a unique session ID for this client. This UUID must be unique to allow // concurrent Spark sessions of the same user. If the channel is closed, creating @@ -60,6 +61,8 @@ class SparkConnectClient( .newBuilder() .setPlan(plan) .setUserContext(userContext) + .setClientId(sessionId) + .setClientType(userAgent) .build() stub.executePlan(request) } @@ -77,6 +80,7 @@ class SparkConnectClient( .setExplain(proto.Explain.newBuilder().setExplainMode(mode)) .setUserContext(userContext) .setClientId(sessionId) + .setClientType(userAgent) .build() analyze(request) } @@ -89,7 +93,21 @@ class SparkConnectClient( } } -object SparkConnectClient { +private[sql] object SparkConnectClient { + + private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA" + + private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] = + Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER) + + private val AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG: String = + "Authentication token cannot be passed over insecure connections. " + + "Either remove 'token' or set 'use_ssl=true'" + + // for internal tests + def apply(userContext: UserContext, channel: ManagedChannel): SparkConnectClient = + new SparkConnectClient(userContext, channel, DEFAULT_USER_AGENT) + def builder(): Builder = new Builder() /** @@ -98,14 +116,27 @@ object SparkConnectClient { */ class Builder() { private val userContextBuilder = proto.UserContext.newBuilder() + private var userAgent: Option[String] = None + private var host: String = "localhost" private var port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT + private var token: Option[String] = None + // If no value specified for isSslEnabled, default to false + private var isSslEnabled: Option[Boolean] = None + + private var metadata: Map[String, String] = Map.empty + def userId(id: String): Builder = { userContextBuilder.setUserId(id) this } + def userName(name: String): Builder = { + userContextBuilder.setUserName(name) + this + } + def host(inputHost: String): Builder = { require(inputHost != null) host = inputHost @@ -117,10 +148,58 @@ object SparkConnectClient { this } + /** + * Setting the token implicitly sets the use_ssl=true. All the following examples yield the + * same results: + * + * {{{ + * sc://localhost/;token=aaa + * sc://localhost/;use_ssl=true;token=aaa + * sc://localhost/;token=aaa;use_ssl=true + * }}} + * + * Throws exception if the token is set but use_ssl=false. + * + * @param inputToken + * the user token. + * @return + * this builder. + */ + def token(inputToken: String): Builder = { + require(inputToken != null && inputToken.nonEmpty) + token = Some(inputToken) + // Only set the isSSlEnabled if it is not yet set + isSslEnabled match { + case None => isSslEnabled = Some(true) + case Some(false) => + throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG) + case Some(true) => // Good, the ssl is enabled + } + this + } + + def enableSsl(): Builder = { + isSslEnabled = Some(true) + this + } + + /** + * Disables the SSL. Throws exception if the token has been set. + * + * @return + * this builder. + */ + def disableSsl(): Builder = { + require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG) + isSslEnabled = Some(false) + this + } + private object URIParams { val PARAM_USER_ID = "user_id" val PARAM_USE_SSL = "use_ssl" val PARAM_TOKEN = "token" + val PARAM_USER_AGENT = "user_agent" } private def verifyURI(uri: URI): Unit = { @@ -146,6 +225,12 @@ object SparkConnectClient { } } + def userAgent(value: String): Builder = { + require(value != null) + userAgent = Some(value) + this + } + private def parseURIParams(uri: URI): Unit = { val params = uri.getPath.split(';').drop(1).filter(_ != "") params.foreach { kv => @@ -158,13 +243,13 @@ object SparkConnectClient { } (arr(0), arr(1)) } - if (key == URIParams.PARAM_USER_ID) { - userContextBuilder.setUserId(value) - } else { - // TODO(SPARK-41917): Support SSL and Auth tokens. - throw new UnsupportedOperationException( - "Parameters apart from user_id" + - " are currently unsupported.") + key match { + case URIParams.PARAM_USER_ID => userId(value) + case URIParams.PARAM_USER_AGENT => userAgent(value) + case URIParams.PARAM_TOKEN => token(value) + case URIParams.PARAM_USE_SSL => + if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl() + case _ => this.metadata = this.metadata + (key -> value) } } } @@ -176,7 +261,6 @@ object SparkConnectClient { * Note: The connection string, if used, will override any previous host/port settings. */ def connectionString(connectionString: String): Builder = { - // TODO(SPARK-41917): Support SSL and Auth tokens. val uri = new URI(connectionString) verifyURI(uri) parseURIParams(uri) @@ -189,9 +273,84 @@ object SparkConnectClient { } def build(): SparkConnectClient = { - val channelBuilder = ManagedChannelBuilder.forAddress(host, port).usePlaintext() + val creds = isSslEnabled match { + case Some(false) | None => InsecureChannelCredentials.create() + case Some(true) => + token match { + case Some(t) => + // With access token added in the http header. + CompositeChannelCredentials.create( + TlsChannelCredentials.create, + new AccessTokenCallCredentials(t)) + case None => + TlsChannelCredentials.create + } + } + + val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds) + if (metadata.nonEmpty) { + channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata)) + } val channel: ManagedChannel = channelBuilder.build() - new SparkConnectClient(userContextBuilder.build(), channel) + new SparkConnectClient( + userContextBuilder.build(), + channel, + userAgent.getOrElse(DEFAULT_USER_AGENT)) + } + } + + /** + * A [[CallCredentials]] created from an access token. + * + * @param token + * A string to place directly in the http request authorization header, for example + * "authorization: Bearer <access_token>". + */ + private[client] class AccessTokenCallCredentials(token: String) extends CallCredentials { + override def applyRequestMetadata( + requestInfo: CallCredentials.RequestInfo, + appExecutor: Executor, + applier: CallCredentials.MetadataApplier): Unit = { + appExecutor.execute(() => { + try { + val headers = new Metadata() + headers.put(AUTH_TOKEN_META_DATA_KEY, s"Bearer $token"); + applier.apply(headers) + } catch { + case e: Throwable => + applier.fail(Status.UNAUTHENTICATED.withCause(e)); + } + }) + } + + override def thisUsesUnstableApi(): Unit = { + // Marks this API is not stable. Left empty on purpose. + } + } + + /** + * A client interceptor to pass extra parameters in http request header. + * + * @param metadata + * extra metadata placed in the http request header, for example "key: value". + */ + private[client] class MetadataHeaderClientInterceptor(metadata: Map[String, String]) + extends ClientInterceptor { + override def interceptCall[ReqT, RespT]( + method: MethodDescriptor[ReqT, RespT], + callOptions: CallOptions, + next: Channel): ClientCall[ReqT, RespT] = { + new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT]( + next.newCall(method, callOptions)) { + override def start( + responseListener: ClientCall.Listener[RespT], + headers: Metadata): Unit = { + metadata.foreach { case (key, value) => + headers.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value) + } + super.start(responseListener, headers) + } + } } } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 6a54cc88aec..b759471e777 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -99,7 +99,7 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit override protected def beforeAll(): Unit = { super.beforeAll() - val client = new SparkConnectClient( + val client = SparkConnectClient( proto.UserContext.newBuilder().build(), InProcessChannelBuilder.forName("/dev/null").build()) val builder = SparkSession.builder().client(client) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 908eddbe7bf..98dacbcab89 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.client import java.util.concurrent.TimeUnit -import io.grpc.Server +import io.grpc.{Server, StatusRuntimeException} import io.grpc.netty.NettyServerBuilder import io.grpc.stub.StreamObserver import org.scalatest.BeforeAndAfterEach @@ -65,10 +65,11 @@ class SparkConnectClientSuite assert(client.userId == "abc123") } - private def testClientConnection( - client: SparkConnectClient, - serverPort: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT): Unit = { + // Use 0 to start the server at a random port + private def testClientConnection(serverPort: Int = 0)( + clientBuilder: Int => SparkConnectClient): Unit = { startDummyServer(serverPort) + client = clientBuilder(server.getPort) val request = AnalyzePlanRequest .newBuilder() .setClientId("abc123") @@ -79,15 +80,28 @@ class SparkConnectClientSuite } test("Test connection") { - val testPort = 16001 - client = SparkConnectClient.builder().port(testPort).build() - testClientConnection(client, testPort) + testClientConnection() { testPort => SparkConnectClient.builder().port(testPort).build() } } test("Test connection string") { - val testPort = 16000 - client = SparkConnectClient.builder().connectionString("sc://localhost:16000").build() - testClientConnection(client, testPort) + testClientConnection() { testPort => + SparkConnectClient.builder().connectionString(s"sc://localhost:$testPort").build() + } + } + + test("Test encryption") { + startDummyServer(0) + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true") + .build() + + val request = AnalyzePlanRequest.newBuilder().setClientId("abc123").build() + + // Failed the ssl handshake as the dummy server does not have any server credentials installed. + assertThrows[StatusRuntimeException] { + client.analyze(request) + } } private case class TestPackURI( @@ -97,17 +111,27 @@ class SparkConnectClientSuite private val URIs = Seq[TestPackURI]( TestPackURI("sc://host", isCorrect = true), - TestPackURI("sc://localhost/", isCorrect = true, client => testClientConnection(client)), + TestPackURI( + "sc://localhost/", + isCorrect = true, + client => testClientConnection(ConnectCommon.CONNECT_GRPC_BINDING_PORT)(_ => client)), TestPackURI( "sc://localhost:1234/", isCorrect = true, - client => testClientConnection(client, 1234)), - TestPackURI("sc://localhost/;", isCorrect = true, client => testClientConnection(client)), + client => testClientConnection(1234)(_ => client)), + TestPackURI( + "sc://localhost/;", + isCorrect = true, + client => testClientConnection(ConnectCommon.CONNECT_GRPC_BINDING_PORT)(_ => client)), TestPackURI("sc://host:123", isCorrect = true), TestPackURI( "sc://host:123/;user_id=a94", isCorrect = true, client => assert(client.userId == "a94")), + TestPackURI( + "sc://host:123/;user_agent=a945", + isCorrect = true, + client => assert(client.userAgent == "a945")), TestPackURI("scc://host:12", isCorrect = false), TestPackURI("http://host", isCorrect = false), TestPackURI("sc:/host:1234/path", isCorrect = false), @@ -116,7 +140,15 @@ class SparkConnectClientSuite TestPackURI("sc://host:123;user_id=a94", isCorrect = false), TestPackURI("sc:///user_id=123", isCorrect = false), TestPackURI("sc://host:-4", isCorrect = false), - TestPackURI("sc://:123/", isCorrect = false)) + TestPackURI("sc://:123/", isCorrect = false), + TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true), + TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true), + TestPackURI("sc://host:123/;token=", isCorrect = false), + TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true), + TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true), + TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false), + TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false), + TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true)) private def checkTestPack(testPack: TestPackURI): Unit = { val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build() @@ -132,22 +164,6 @@ class SparkConnectClientSuite } } } - - // TODO(SPARK-41917): Remove test once SSL and Auth tokens are supported. - test("Non user-id parameters throw unsupported errors") { - assertThrows[UnsupportedOperationException] { - SparkConnectClient.builder().connectionString("sc://host/;use_ssl=true").build() - } - - assertThrows[UnsupportedOperationException] { - SparkConnectClient.builder().connectionString("sc://host/;token=abc").build() - } - - assertThrows[UnsupportedOperationException] { - SparkConnectClient.builder().connectionString("sc://host/;xyz=abc").build() - - } - } } class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org