This is an automated email from the ASF dual-hosted git repository. mridulm80 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 08c7bacb4f9 [SPARK-45545][CORE] Pass SSLOptions wherever we create a SparkTransportConf 08c7bacb4f9 is described below commit 08c7bacb4f9dc0343ba9730b00d792cec7a1cf1e Author: Hasnain Lakhani <hasnain.lakh...@databricks.com> AuthorDate: Wed Oct 25 01:35:56 2023 -0500 [SPARK-45545][CORE] Pass SSLOptions wherever we create a SparkTransportConf
### What changes were proposed in this pull request?
This change ensures that inheritance of RPC SSL option settings works properly after https://github.com/apache/spark/pull/43238: we now pass `sslOptions` wherever we call `fromSparkConf`. In addition to that minor mechanical change, we duplicate or add tests for every place that calls this method, so that each one has a test case that runs with SSL support enabled in the config.
### Why are the changes needed?
These changes are needed to ensure that the RPC SSL functionality can work properly with settings inheritance. In addition, these tests ensure that any future changes to these modules are also tested with SSL support, which helps avoid regressions.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Full integration testing was also done as part of https://github.com/apache/spark/pull/42685

Added some tests and ran them:
```
build/sbt
> project core
> testOnly org.apache.spark.*Ssl*
> testOnly org.apache.spark.network.netty.NettyBlockTransferSecuritySuite
```
and
```
build/sbt -Pyarn
> project yarn
> testOnly org.apache.spark.network.yarn.SslYarnShuffleServiceWithRocksDBBackendSuite
```
### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43387 from hasnain-db/spark-tls-integrate-everywhere.
Authored-by: Hasnain Lakhani <hasnain.lakh...@databricks.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> --- .../scala/org/apache/spark/SecurityManager.scala | 6 +++ .../src/main/scala/org/apache/spark/SparkEnv.scala | 7 ++- .../spark/deploy/ExternalShuffleService.scala | 6 ++- .../executor/CoarseGrainedExecutorBackend.scala | 4 +- .../network/netty/NettyBlockTransferService.scala | 6 ++- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 10 ++++- .../spark/shuffle/IndexShuffleBlockResolver.scala | 8 +++- .../apache/spark/shuffle/ShuffleBlockPusher.scala | 6 ++- .../org/apache/spark/storage/BlockManager.scala | 3 +- .../apache/spark/ExternalShuffleServiceSuite.scala | 8 +++- .../spark/SslExternalShuffleServiceSuite.scala | 52 ++++++++++++++++++++++ .../org/apache/spark/SslShuffleNettySuite.scala | 26 +++++++++++ .../CoarseGrainedExecutorBackendSuite.scala | 40 +++++++++++------ .../netty/NettyBlockTransferSecuritySuite.scala | 45 +++++++++++++------ .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 52 ++++++++++++---------- .../apache/spark/rpc/netty/NettyRpcEnvSuite.scala | 12 +++-- .../spark/shuffle/ShuffleBlockPusherSuite.scala | 12 ++++- .../sort/IndexShuffleBlockResolverSuite.scala | 14 +++++- .../storage/BlockManagerReplicationSuite.scala | 40 ++++++++++------- .../storage/SslBlockManagerReplicationSuite.scala | 39 ++++++++++++++++ .../scala/org/apache/spark/util/SslTestUtils.scala | 35 +++++++++++++++ .../network/yarn/SslYarnShuffleServiceSuite.scala | 34 ++++++++++++++ 22 files changed, 379 insertions(+), 86 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index f8961fff8e1..ee9051d024c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -292,6 +292,12 @@ private[spark] class SecurityManager( */ def isSslRpcEnabled(): Boolean = sslRpcEnabled + /** + * Returns 
the SSLOptions object for the RPC namespace + * @return the SSLOptions object for the RPC namespace + */ + def getRpcSSLOptions(): SSLOptions = rpcSSLOptions + /** * Gets the user used for authenticating SASL connections. * For now use a single hardcoded user. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 310dc828440..c2bae41d34e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -374,7 +374,12 @@ object SparkEnv extends Logging { } val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + val transConf = SparkTransportConf.fromSparkConf( + conf, + "shuffle", + numUsableCores, + sslOptions = Some(securityManager.getRpcSSLOptions()) + ) Some(new ExternalBlockStoreClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 466c1f2e14b..a56fbd5a644 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -53,7 +53,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val registeredExecutorsDB = "registeredExecutors" private val transportConf = - SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) + SparkTransportConf.fromSparkConf( + sparkConf, + "shuffle", + numUsableCores = 0, + sslOptions = Some(securityManager.getRpcSSLOptions())) private val blockHandler = newShuffleBlockHandler(transportConf) private var transportContext: TransportContext = _ diff --git 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b074ac814a9..f964e2b50b5 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -90,7 +90,9 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) try { - val shuffleClientTransportConf = SparkTransportConf.fromSparkConf(env.conf, "shuffle") + val securityManager = new SecurityManager(env.conf) + val shuffleClientTransportConf = SparkTransportConf.fromSparkConf( + env.conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) if (NettyUtils.preferDirectBufs(shuffleClientTransportConf) && PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) { throw new SparkException(s"Netty direct memory should at least be bigger than " + diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index f54383db4c0..6b785a07c7f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -70,7 +70,11 @@ private[spark] class NettyBlockTransferService( val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None - this.transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) + this.transportConf = SparkTransportConf.fromSparkConf( + conf, + "shuffle", + numCores, + sslOptions = Some(securityManager.getRpcSSLOptions())) if (authEnabled) { serverBootstrap = Some(new 
AuthServerBootstrap(transportConf, securityManager)) clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 464b6cbc6b0..7909f2327cd 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -56,7 +56,9 @@ private[netty] class NettyRpcEnv( conf.clone.set(RPC_IO_NUM_CONNECTIONS_PER_PEER, 1), "rpc", conf.get(RPC_IO_THREADS).getOrElse(numUsableCores), - role) + role, + sslOptions = Some(securityManager.getRpcSSLOptions()) + ) private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) @@ -391,7 +393,11 @@ private[netty] class NettyRpcEnv( } val ioThreads = clone.getInt("spark.files.io.threads", 1) - val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadConf = SparkTransportConf.fromSparkConf( + clone, + module, + ioThreads, + sslOptions = Some(securityManager.getRpcSSLOptions())) val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 919b0f5f7c1..ab34bae996c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -24,7 +24,7 @@ import java.nio.file.Files import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException} import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.internal.{config, Logging} 
import org.apache.spark.io.NioBufferedFileInputStream @@ -58,7 +58,11 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + private val transportConf = { + val securityManager = new SecurityManager(conf) + SparkTransportConf.fromSparkConf( + conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) + } private val remoteShuffleMaxDisk: Option[Long] = conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index ac43ba8b56f..252f929da28 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -25,7 +25,7 @@ import java.util.concurrent.ExecutorService import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.util.control.NonFatal -import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{SecurityManager, ShuffleDependency, SparkConf, SparkContext, SparkEnv} import org.apache.spark.annotation.Since import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend} import org.apache.spark.internal.Logging @@ -108,7 +108,9 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { dep: ShuffleDependency[_, _, _], mapIndex: Int): Unit = { val numPartitions = dep.partitioner.numPartitions - val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + val securityManager = new SecurityManager(conf) + val transportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) this.shuffleId = dep.shuffleId this.shuffleMergeId = dep.shuffleMergeId this.mapIndex = 
mapIndex diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index aa9ba7c34f6..f77fda46149 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1252,7 +1252,8 @@ private[spark] class BlockManager( new EncryptedBlockData(file, blockSize, conf, key)) case _ => - val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + val transportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions())) new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } Some(managedBuffer) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 1d1bb9e9eee..f0a63247e64 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -49,8 +49,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi var transportContext: TransportContext = _ var rpcHandler: ExternalBlockHandler = _ - override def beforeAll(): Unit = { - super.beforeAll() + protected def initializeHandlers(): Unit = { val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalBlockHandler(transportConf, null) transportContext = new TransportContext(transportConf, rpcHandler) @@ -61,6 +60,11 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort) } + override def beforeAll(): Unit = { + super.beforeAll() + initializeHandlers() + } + override def afterAll(): Unit = { Utils.tryLogNonFatalError{ server.close() diff --git a/core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala 
b/core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala new file mode 100644 index 00000000000..3ce1f11a7ac --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.config +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.ExternalBlockHandler + +/** + * This suite creates an external shuffle server and routes all shuffle fetches through it. + * Note that failures in this suite may arise due to changes in Spark that invalidate expectations + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how + * we hash files into folders. 
+ */ +class SslExternalShuffleServiceSuite extends ExternalShuffleServiceSuite { + + override def initializeHandlers(): Unit = { + SslTestUtils.updateWithSSLConfig(conf) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf); + // Show that we can successfully inherit options defined in the `spark.ssl` namespace + val defaultSslOptions = SSLOptions.parse(conf, hadoopConf, "spark.ssl") + val sslOptions = SSLOptions.parse( + conf, hadoopConf, "spark.ssl.rpc", defaults = Some(defaultSslOptions)) + val transportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", numUsableCores = 2, sslOptions = Some(sslOptions)) + + rpcHandler = new ExternalBlockHandler(transportConf, null) + transportContext = new TransportContext(transportConf, rpcHandler) + server = transportContext.createServer() + + conf.set(config.SHUFFLE_MANAGER, "sort") + conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort) + } +} diff --git a/core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala new file mode 100644 index 00000000000..7eaff7d37a8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +class SslShuffleNettySuite extends ShuffleNettySuite { + + override def beforeAll(): Unit = { + super.beforeAll() + SslTestUtils.updateWithSSLConfig(conf) + } +} diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 7a7021357ed..3ef4da6d3d3 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -52,8 +52,12 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite implicit val formats = DefaultFormats + def createSparkConf(): SparkConf = { + new SparkConf() + } + test("parsing no resources") { - val conf = new SparkConf + val conf = createSparkConf() val resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf) val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) @@ -75,7 +79,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("parsing one resource") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "2") val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) @@ -100,11 +104,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite val ereqs = new ExecutorResourceRequests().resource(GPU, 2) ereqs.resource(FPGA, 3) val rp = rpBuilder.require(ereqs).build() - testParsingMultipleResources(new SparkConf, rp) + testParsingMultipleResources(createSparkConf(), rp) } test("parsing multiple resources") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "2") conf.set(EXECUTOR_FPGA_ID.amountConf, "3") testParsingMultipleResources(conf, 
ResourceProfile.getOrCreateDefaultProfile(conf)) @@ -136,7 +140,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("error checking parsing resources and executor and task configs") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "2") val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) @@ -178,11 +182,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite val ereqs = new ExecutorResourceRequests().resource(GPU, 4) val treqs = new TaskResourceRequests().resource(GPU, 1) val rp = rpBuilder.require(ereqs).require(treqs).build() - testExecutorResourceFoundLessThanRequired(new SparkConf, rp) + testExecutorResourceFoundLessThanRequired(createSparkConf(), rp) } test("executor resource found less than required") { - val conf = new SparkConf() + val conf = createSparkConf() conf.set(EXECUTOR_GPU_ID.amountConf, "4") conf.set(TASK_GPU_ID.amountConf, "1") testExecutorResourceFoundLessThanRequired(conf, ResourceProfile.getOrCreateDefaultProfile(conf)) @@ -213,7 +217,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("use resource discovery") { - val conf = new SparkConf + val conf = createSparkConf() conf.set(EXECUTOR_FPGA_ID.amountConf, "3") assume(!(Utils.isWindows)) withTempDir { dir => @@ -246,7 +250,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite val ereqs = new ExecutorResourceRequests().resource(FPGA, 3, scriptPath) ereqs.resource(GPU, 2) val rp = rpBuilder.require(ereqs).build() - allocatedFileAndConfigsResourceDiscoveryTestFpga(dir, new SparkConf, rp) + allocatedFileAndConfigsResourceDiscoveryTestFpga(dir, createSparkConf(), rp) } } @@ -255,7 +259,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite withTempDir { dir => val scriptPath = createTempScriptWithExpectedOutput(dir, "fpgaDiscoverScript", """{"name": "fpga","addresses":["f1", "f2", "f3"]}""") - val conf = new SparkConf + val conf = 
createSparkConf() conf.set(EXECUTOR_FPGA_ID.amountConf, "3") conf.set(EXECUTOR_FPGA_ID.discoveryScriptConf, scriptPath) conf.set(EXECUTOR_GPU_ID.amountConf, "2") @@ -289,7 +293,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } test("track allocated resources by taskId") { - val conf = new SparkConf + val conf = createSparkConf() val securityMgr = new SecurityManager(conf) val serializer = new JavaSerializer(conf) var backend: CoarseGrainedExecutorBackend = null @@ -389,7 +393,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite * being executed in [[Executor.TaskRunner]]. */ test(s"Tasks launched should always be cancelled.") { - val conf = new SparkConf + val conf = createSparkConf() val securityMgr = new SecurityManager(conf) val serializer = new JavaSerializer(conf) val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor") @@ -478,7 +482,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite * it has not been launched yet. */ test(s"Tasks not launched should always be cancelled.") { - val conf = new SparkConf + val conf = createSparkConf() val securityMgr = new SecurityManager(conf) val serializer = new JavaSerializer(conf) val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor") @@ -567,7 +571,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite * [[SparkUncaughtExceptionHandler]] and [[Executor]] can exit by itself. 
*/ test("SPARK-40320 Executor should exit when initialization failed for fatal error") { - val conf = new SparkConf() + val conf = createSparkConf() .setMaster("local-cluster[1, 1, 1024]") .set(PLUGINS, Seq(classOf[TestFatalErrorPlugin].getName)) .setAppName("test") @@ -628,3 +632,11 @@ private class TestErrorExecutorPlugin extends ExecutorPlugin { // scalastyle:on throwerror } } + +class SslCoarseGrainedExecutorBackendSuite extends CoarseGrainedExecutorBackendSuite + with LocalSparkContext with MockitoSugar { + + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 85b05cd5f98..5c234ef9550 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -32,7 +32,7 @@ import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite, SslTestUtils} import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -43,8 +43,15 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId} import org.apache.spark.util.ThreadUtils class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with Matchers { + + def createSparkConf(): SparkConf = { + new SparkConf() + } + + def isRunningWithSSL(): Boolean = false + test("security default off") { - val conf = new SparkConf() + val conf = createSparkConf() .set("spark.app.id", "app-id") testConnection(conf, conf) 
match { case Success(_) => // expected @@ -53,7 +60,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security on same password") { - val conf = new SparkConf() + val conf = createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set("spark.app.id", "app-id") @@ -64,7 +71,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security on mismatch password") { - val conf0 = new SparkConf() + val conf0 = createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set("spark.app.id", "app-id") @@ -76,7 +83,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security mismatch auth off on server") { - val conf0 = new SparkConf() + val conf0 = createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set("spark.app.id", "app-id") @@ -100,15 +107,17 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } test("security with aes encryption") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(AUTH_SECRET, "good") - .set("spark.app.id", "app-id") - .set(Network.NETWORK_CRYPTO_ENABLED, true) - .set(Network.NETWORK_CRYPTO_SASL_FALLBACK, false) - testConnection(conf, conf) match { - case Success(_) => // expected - case Failure(t) => fail(t) + if (!isRunningWithSSL()) { + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(AUTH_SECRET, "good") + .set("spark.app.id", "app-id") + .set(Network.NETWORK_CRYPTO_ENABLED, true) + .set(Network.NETWORK_CRYPTO_SASL_FALLBACK, false) + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } } } @@ -179,3 +188,11 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } } +class SslNettyBlockTransferSecuritySuite extends NettyBlockTransferSecuritySuite { + + override def isRunningWithSSL(): Boolean = true + + override 
def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index a88be983b80..3ef38257351 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -44,9 +44,13 @@ abstract class RpcEnvSuite extends SparkFunSuite { var env: RpcEnv = _ + def createSparkConf(): SparkConf = { + new SparkConf() + } + override def beforeAll(): Unit = { super.beforeAll() - val conf = new SparkConf() + val conf = createSparkConf() env = createRpcEnv(conf, "local", 0) val sparkEnv = mock(classOf[SparkEnv]) @@ -93,7 +97,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") try { @@ -145,7 +149,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { @@ -168,7 +172,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val conf = new SparkConf() + val conf = createSparkConf() val shortProp = "spark.rpc.short.timeout" val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef @@ -198,7 +202,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val conf = new SparkConf() + val conf = createSparkConf() val shortProp = "spark.rpc.short.timeout" val anotherEnv = 
createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef @@ -467,7 +471,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { @@ -507,7 +511,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") try { @@ -556,8 +560,8 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("network events in sever RpcEnv when another RpcEnv is in server mode") { - val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) - val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val serverEnv1 = createRpcEnv(createSparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(createSparkConf(), "server2", 0, clientMode = false) val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") try { @@ -585,9 +589,9 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("network events in sever RpcEnv when another RpcEnv is in client mode") { - val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val serverEnv = createRpcEnv(createSparkConf(), "server", 0, clientMode = false) val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") - val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val clientEnv = 
createRpcEnv(createSparkConf(), "client", 0, clientMode = true) try { val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) // Send a message to set up the connection @@ -615,8 +619,8 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("network events in client RpcEnv when another RpcEnv is in server mode") { - val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) - val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val clientEnv = createRpcEnv(createSparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(createSparkConf(), "server", 0, clientMode = false) val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") try { @@ -652,7 +656,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") @@ -669,7 +673,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("port conflict") { - val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", env.address.port) try { assert(anotherEnv.address.port != env.address.port) } finally { @@ -729,20 +733,20 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("send with authentication") { - testSend(new SparkConf() + testSend(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good")) } test("send with SASL encryption") { - testSend(new SparkConf() + testSend(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(SASL_ENCRYPTION_ENABLED, true)) } test("send with AES encryption") { - 
testSend(new SparkConf() + testSend(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(Network.NETWORK_CRYPTO_ENABLED, true) @@ -750,20 +754,20 @@ abstract class RpcEnvSuite extends SparkFunSuite { } test("ask with authentication") { - testAsk(new SparkConf() + testAsk(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good")) } test("ask with SASL encryption") { - testAsk(new SparkConf() + testAsk(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(SASL_ENCRYPTION_ENABLED, true)) } test("ask with AES encryption") { - testAsk(new SparkConf() + testAsk(createSparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(AUTH_SECRET, "good") .set(Network.NETWORK_CRYPTO_ENABLED, true) @@ -861,7 +865,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { test("file server") { withTempDir { tempDir => withTempDir { destDir => - val conf = new SparkConf() + val conf = createSparkConf() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) @@ -940,7 +944,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0) + val anotherEnv = createRpcEnv(createSparkConf(), "remote", 0) val endpoint = mock(classOf[RpcEndpoint]) anotherEnv.setupEndpoint("SPARK-14699", endpoint) @@ -960,7 +964,7 @@ abstract class RpcEnvSuite extends SparkFunSuite { test("isolated endpoints") { val latch = new CountDownLatch(1) val singleThreadedEnv = createRpcEnv( - new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) + createSparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) try { val blockingEndpoint = singleThreadedEnv .setupEndpoint("blocking", new IsolatedThreadSafeRpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index fe6d0db837b..dcd40b6afd5 100644 
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -53,7 +53,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { } test("advertise address different from bind address") { - val sparkConf = new SparkConf() + val sparkConf = createSparkConf() val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0, new SecurityManager(sparkConf), 0, false) val env = new NettyRpcEnvFactory().create(config) @@ -95,7 +95,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { test("StackOverflowError should be sent back and Dispatcher should survive") { val numUsableCores = 2 - val conf = new SparkConf + val conf = createSparkConf() val config = RpcEnvConfig( conf, "test", @@ -150,7 +150,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { context.reply(msg) } }) - val conf = new SparkConf() + val conf = createSparkConf() val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely-server") @@ -180,3 +180,9 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar with TimeLimits { } } } + +class SslNettyRpcEnvSuite extends NettyRpcEnvSuite with MockitoSugar with TimeLimits { + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala index 18c27ff1269..99f113ec16a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala @@ -53,9 +53,13 @@ class ShuffleBlockPusherSuite extends SparkFunSuite { private var conf: 
SparkConf = _ private var pushedBlocks = new ArrayBuffer[String] + def createSparkConf(): SparkConf = { + new SparkConf(loadDefaults = false) + } + override def beforeEach(): Unit = { super.beforeEach() - conf = new SparkConf(loadDefaults = false) + conf = createSparkConf() MockitoAnnotations.openMocks(this).close() when(dependency.shuffleId).thenReturn(0) when(dependency.partitioner).thenReturn(new HashPartitioner(8)) @@ -480,3 +484,9 @@ class ShuffleBlockPusherSuite extends SparkFunSuite { } } } + +class SslShuffleBlockPusherSuite extends ShuffleBlockPusherSuite { + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 31b255cff72..8a9537b4f18 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -26,7 +26,7 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, SslTestUtils} import org.apache.spark.internal.config import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.storage._ @@ -37,8 +37,12 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite { @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + def createSparkConf(): SparkConf = { + new SparkConf(loadDefaults = false) + } + private var tempDir: File = _ - private val conf: SparkConf = new SparkConf(loadDefaults = false) + private val conf: SparkConf = createSparkConf() private val appId 
= "TESTAPP" override def beforeEach(): Unit = { @@ -275,3 +279,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite { assert(checksumsInMemory === checksumsFromFile) } } + +class SslIndexShuffleBlockResolverSuite extends IndexShuffleBlockResolverSuite { + override def createSparkConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createSparkConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 38a669bc857..1fbc900727c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { - val conf: SparkConf + val conf: SparkConf = createConf() + protected def createConf(): SparkConf protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null @@ -459,15 +460,21 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } class BlockManagerReplicationSuite extends BlockManagerReplicationBehavior { - val conf = new SparkConf(false).set("spark.app.id", "test") - conf.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + override def createConf(): SparkConf = { + new SparkConf(false) + .set("spark.app.id", "test") + .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + } } class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehavior { - val conf = new SparkConf(false).set("spark.app.id", "test") - conf.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") - conf.set(STORAGE_REPLICATION_PROACTIVE, true) - conf.set(STORAGE_EXCEPTION_PIN_LEAK, true) + override def createConf(): SparkConf = { + new SparkConf(false) + .set("spark.app.id", "test") + .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + .set(STORAGE_REPLICATION_PROACTIVE, true) + 
.set(STORAGE_EXCEPTION_PIN_LEAK, true) + } (2 to 5).foreach { i => test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { @@ -539,14 +546,17 @@ class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Log } class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { - val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") - conf.set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") - conf.set( - STORAGE_REPLICATION_POLICY, - classOf[BasicBlockReplicationPolicy].getName) - conf.set( - STORAGE_REPLICATION_TOPOLOGY_MAPPER, - classOf[DummyTopologyMapper].getName) + override def createConf(): SparkConf = { + new SparkConf(false) + .set("spark.app.id", "test") + .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") + .set( + STORAGE_REPLICATION_POLICY, + classOf[BasicBlockReplicationPolicy].getName) + .set( + STORAGE_REPLICATION_TOPOLOGY_MAPPER, + classOf[DummyTopologyMapper].getName) + } } // BlockReplicationPolicy to prioritize BlockManagers based on hostnames diff --git a/core/src/test/scala/org/apache/spark/storage/SslBlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/SslBlockManagerReplicationSuite.scala new file mode 100644 index 00000000000..760f31de059 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/SslBlockManagerReplicationSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.{SparkConf, SslTestUtils} + +class SslBlockManagerReplicationSuite extends BlockManagerReplicationSuite { + override def createConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createConf()) + } +} + +class SslBlockManagerProactiveReplicationSuite extends BlockManagerProactiveReplicationSuite { + override def createConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createConf()) + } +} + +class SslBlockManagerBasicStrategyReplicationSuite + extends BlockManagerBasicStrategyReplicationSuite { + override def createConf(): SparkConf = { + SslTestUtils.updateWithSSLConfig(super.createConf()) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SslTestUtils.scala b/core/src/test/scala/org/apache/spark/util/SslTestUtils.scala new file mode 100644 index 00000000000..dd71a68f625 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SslTestUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.network.ssl.SslSampleConfigs + +object SslTestUtils { + + /** + * Updates a SparkConf to contain SSL configurations + * + * @param conf The config to update + * @return The passed in SparkConf with SSL configurations added + */ + def updateWithSSLConfig(conf: SparkConf): SparkConf = { + SslSampleConfigs.createDefaultConfigMap().entrySet(). + forEach(entry => conf.set(entry.getKey, entry.getValue)) + conf + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala new file mode 100644 index 00000000000..322d6bfdb7c --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
 You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn + +import org.apache.spark.network.ssl.SslSampleConfigs + +class SslYarnShuffleServiceWithRocksDBBackendSuite + extends YarnShuffleServiceWithRocksDBBackendSuite { + + /** + * Override to add "spark.ssl.rpc.*" configuration parameters... + */ + override def beforeEach(): Unit = { + super.beforeEach() + // Same as SslTestUtils.updateWithSSLConfig(), which is not available to import here. + SslSampleConfigs.createDefaultConfigMap().entrySet(). + forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org