This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 884f6f71172 [SPARK-45544][CORE] Integrate SSL support into TransportContext
884f6f71172 is described below

commit 884f6f71172156ccc7d95ed022c8fb8baadc3c0a
Author: Hasnain Lakhani <hasnain.lakh...@databricks.com>
AuthorDate: Sun Oct 29 20:58:18 2023 -0500

    [SPARK-45544][CORE] Integrate SSL support into TransportContext
    
    ### What changes were proposed in this pull request?
    
    This integrates SSL support into TransportContext and related modules so that the RPC SSL functionality works when it is properly configured.
    
    ### Why are the changes needed?
    
    This is needed in order to support SSL for RPC connections.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    CI
    
    Ran the following tests:
    
    ```
    build/sbt -P yarn
    > project network-common
    > testOnly
    > project network-shuffle
    > testOnly
    > project core
    > testOnly *Ssl*
    > project yarn
    > testOnly org.apache.spark.network.yarn.SslYarnShuffleServiceWithRocksDBBackendSuite
    ```
    
    I verified that traffic was encrypted with TLS in two ways:
    
    * Enabled trace-level logging for Netty and JDK SSL and saw log output confirming that TLS handshakes were taking place.
    * Ran Wireshark on my machine and captured traffic while running queries that shuffle a fixed string. Without any encryption, the string was visible in the network traffic. With this encryption enabled, the string did not appear, and the Wireshark capture confirmed that a TLS handshake took place.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43541 from hasnain-db/spark-tls-final.
    
    Authored-by: Hasnain Lakhani <hasnain.lakh...@databricks.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../org/apache/spark/network/TransportContext.java | 70 ++++++++++++++++++++--
 .../network/client/TransportClientFactory.java     | 26 +++++++-
 .../spark/network/server/TransportServer.java      |  2 +-
 .../apache/spark/network/util/TransportConf.java   |  8 ---
 .../spark/network/ChunkFetchIntegrationSuite.java  |  6 +-
 .../network/SslChunkFetchIntegrationSuite.java     | 22 ++++---
 .../client/SslTransportClientFactorySuite.java     | 29 +++++----
 .../client/TransportClientFactorySuite.java        |  8 +--
 .../network/shuffle/ShuffleTransportContext.java   | 10 ++--
 .../shuffle/ExternalShuffleIntegrationSuite.java   | 29 +++++----
 .../shuffle/ExternalShuffleSecuritySuite.java      | 14 ++++-
 .../shuffle/ShuffleTransportContextSuite.java      | 33 +++++-----
 .../SslExternalShuffleIntegrationSuite.java        | 44 ++++++++++++++
 .../shuffle/SslExternalShuffleSecuritySuite.java   | 35 +++++++----
 .../shuffle/SslShuffleTransportContextSuite.java   | 28 +++++----
 .../network/yarn/SslYarnShuffleServiceSuite.scala  |  2 +-
 16 files changed, 265 insertions(+), 101 deletions(-)

diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
 
b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index 51d074a4ddb..90ca4f4c46a 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -23,13 +23,17 @@ import io.netty.handler.codec.MessageToMessageDecoder;
 import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.List;
+import javax.annotation.Nullable;
 
 import com.codahale.metrics.Counter;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelPipeline;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.stream.ChunkedWriteHandler;
 import io.netty.handler.timeout.IdleStateHandler;
+import io.netty.handler.codec.MessageToMessageEncoder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -37,6 +41,8 @@ import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
 import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.SslMessageEncoder;
 import org.apache.spark.network.protocol.MessageDecoder;
 import org.apache.spark.network.protocol.MessageEncoder;
 import org.apache.spark.network.server.ChunkFetchRequestHandler;
@@ -45,6 +51,7 @@ import 
org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.server.TransportRequestHandler;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.ssl.SSLFactory;
 import org.apache.spark.network.util.IOMode;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.NettyLogger;
@@ -72,6 +79,8 @@ public class TransportContext implements Closeable {
   private final TransportConf conf;
   private final RpcHandler rpcHandler;
   private final boolean closeIdleConnections;
+  // Non-null if SSL is enabled, null otherwise.
+  @Nullable private final SSLFactory sslFactory;
   // Number of registered connections to the shuffle service
   private Counter registeredConnections = new Counter();
 
@@ -87,7 +96,8 @@ public class TransportContext implements Closeable {
    * RPC to load it and cause to load the non-exist matcher class again. JVM 
will report
    * `ClassCircularityError` to prevent such infinite recursion. (See 
SPARK-17714)
    */
-  private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+  private static final MessageToMessageEncoder<Message> ENCODER = 
MessageEncoder.INSTANCE;
+  private static final MessageToMessageEncoder<Message> SSL_ENCODER = 
SslMessageEncoder.INSTANCE;
   private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
 
   // Separate thread pool for handling ChunkFetchRequest. This helps to enable 
throttling
@@ -125,6 +135,7 @@ public class TransportContext implements Closeable {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
     this.closeIdleConnections = closeIdleConnections;
+    this.sslFactory = createSslFactory();
 
     if (conf.getModuleName() != null &&
         conf.getModuleName().equalsIgnoreCase("shuffle") &&
@@ -171,8 +182,12 @@ public class TransportContext implements Closeable {
     return createServer(0, new ArrayList<>());
   }
 
-  public TransportChannelHandler initializePipeline(SocketChannel channel) {
-    return initializePipeline(channel, rpcHandler);
+  public TransportChannelHandler initializePipeline(SocketChannel channel, 
boolean isClient) {
+    return initializePipeline(channel, rpcHandler, isClient);
+  }
+
+  public boolean sslEncryptionEnabled() {
+    return this.sslFactory != null;
   }
 
   /**
@@ -189,15 +204,30 @@ public class TransportContext implements Closeable {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, 
channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
       if (nettyLogger.getLoggingHandler() != null) {
         pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
       }
+
+      if (sslEncryptionEnabled()) {
+        SslHandler sslHandler;
+        try {
+          sslHandler = new SslHandler(sslFactory.createSSLEngine(isClient, 
channel.alloc()));
+        } catch (Exception e) {
+          throw new IllegalStateException("Error creating Netty SslHandler", 
e);
+        }
+        pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+        // Cannot use zero-copy with HTTPS, so we add in our 
ChunkedWriteHandler just before the
+        // MessageEncoder
+        pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
+      }
+
       pipeline
-        .addLast("encoder", ENCODER)
+        .addLast("encoder", sslEncryptionEnabled()? SSL_ENCODER : ENCODER)
         .addLast(TransportFrameDecoder.HANDLER_NAME, 
NettyUtils.createFrameDecoder())
         .addLast("decoder", getDecoder())
         .addLast("idleStateHandler",
@@ -223,6 +253,33 @@ public class TransportContext implements Closeable {
     return DECODER;
   }
 
+  private SSLFactory createSslFactory() {
+    if (conf.sslRpcEnabled()) {
+      if (conf.sslRpcEnabledAndKeysAreValid()) {
+        return new SSLFactory.Builder()
+          .openSslEnabled(conf.sslRpcOpenSslEnabled())
+          .requestedProtocol(conf.sslRpcProtocol())
+          .requestedCiphers(conf.sslRpcRequestedCiphers())
+          .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword())
+          .privateKey(conf.sslRpcPrivateKey())
+          .keyPassword(conf.sslRpcKeyPassword())
+          .certChain(conf.sslRpcCertChain())
+          .trustStore(
+            conf.sslRpcTrustStore(),
+            conf.sslRpcTrustStorePassword(),
+            conf.sslRpcTrustStoreReloadingEnabled(),
+            conf.sslRpctrustStoreReloadIntervalMs())
+          .build();
+      } else {
+        logger.error("RPC SSL encryption enabled but keys not found!" +
+          "Please ensure the configured keys are present.");
+        throw new IllegalArgumentException("RPC SSL encryption enabled but 
keys not found!");
+      }
+    } else {
+      return null;
+    }
+  }
+
   /**
    * Creates the server- and client-side handler which is used to handle both 
RequestMessages and
    * ResponseMessages. The channel is expected to have been successfully 
created, though certain
@@ -255,5 +312,8 @@ public class TransportContext implements Closeable {
     if (chunkFetchWorkers != null) {
       chunkFetchWorkers.shutdownGracefully();
     }
+    if (sslFactory != null) {
+      sslFactory.destroy();
+    }
   }
 }
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
 
b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 4c1efd69206..fd48020caac 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -39,6 +39,9 @@ import io.netty.channel.ChannelInitializer;
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -268,7 +271,7 @@ public class TransportClientFactory implements Closeable {
     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
       @Override
       public void initChannel(SocketChannel ch) {
-        TransportChannelHandler clientHandler = context.initializePipeline(ch);
+        TransportChannelHandler clientHandler = context.initializePipeline(ch, 
true);
         clientRef.set(clientHandler.getClient());
         channelRef.set(ch);
       }
@@ -293,6 +296,27 @@ public class TransportClientFactory implements Closeable {
     } else if (cf.cause() != null) {
       throw new IOException(String.format("Failed to connect to %s", address), 
cf.cause());
     }
+    if (context.sslEncryptionEnabled()) {
+      final SslHandler sslHandler = 
cf.channel().pipeline().get(SslHandler.class);
+      Future<Channel> future = sslHandler.handshakeFuture().addListener(
+        new GenericFutureListener<Future<Channel>>() {
+          @Override
+          public void operationComplete(final Future<Channel> handshakeFuture) 
{
+            if (handshakeFuture.isSuccess()) {
+              logger.debug("{} successfully completed TLS handshake to ", 
address);
+            } else {
+              logger.info(
+                "failed to complete TLS handshake to " + address, 
handshakeFuture.cause());
+              cf.channel().close();
+            }
+          }
+      });
+      if (!future.await(conf.connectionTimeoutMs())) {
+        cf.channel().close();
+        throw new IOException(
+          String.format("Failed to connect to %s within connection timeout", 
address));
+      }
+    }
 
     TransportClient client = clientRef.get();
     Channel channel = channelRef.get();
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
 
b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 5b5b3f9d901..6f2e4b8a502 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -140,7 +140,7 @@ public class TransportServer implements Closeable {
         for (TransportServerBootstrap bootstrap : bootstraps) {
           rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
         }
-        context.initializePipeline(ch, rpcHandler);
+        context.initializePipeline(ch, rpcHandler, false);
       }
     });
 
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 3ebb38e310f..eb85d2bb561 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -401,14 +401,6 @@ public class TransportConf {
     }
   }
 
-  /**
-   * If we can dangerously fallback to unencrypted connections if RPC over SSL 
is enabled
-   * but the key files are not present
-   */
-  public boolean sslRpcDangerouslyFallbackIfKeysNotPresent() {
-    return 
conf.getBoolean("spark.ssl.rpc.dangerouslyFallbackIfKeysNotPresent", false);
-  }
-
   /**
    * Flag indicating whether to share the pooled ByteBuf allocators between 
the different Netty
    * channels. If enabled then only two pooled ByteBuf allocators are created: 
one where caching
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 2026d3b9524..576a106934f 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -65,8 +65,13 @@ public class ChunkFetchIntegrationSuite {
   static ManagedBuffer bufferChunk;
   static ManagedBuffer fileChunk;
 
+  // This is split out so it can be invoked in a subclass with a different 
config
   @BeforeAll
   public static void setUp() throws Exception {
+    doSetUpWithConfig(new TransportConf("shuffle", MapConfigProvider.EMPTY));
+  }
+
+  public static void doSetUpWithConfig(final TransportConf conf) throws 
Exception {
     int bufSize = 100000;
     final ByteBuffer buf = ByteBuffer.allocate(bufSize);
     for (int i = 0; i < bufSize; i ++) {
@@ -88,7 +93,6 @@ public class ChunkFetchIntegrationSuite {
       Closeables.close(fp, shouldSuppressIOException);
     }
 
-    final TransportConf conf = new TransportConf("shuffle", 
MapConfigProvider.EMPTY);
     fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, 
testFile.length() - 25);
 
     streamManager = new StreamManager() {
diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
 
b/common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java
similarity index 59%
copy from 
resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to 
common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java
index 322d6bfdb7c..783ffd4b8c1 100644
--- 
a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java
@@ -14,21 +14,19 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.network;
 
-package org.apache.spark.network.yarn
+import org.junit.jupiter.api.BeforeAll;
 
-import org.apache.spark.network.ssl.SslSampleConfigs
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
 
-class SslYarnShuffleServiceWithRocksDBBackendSuite
-  extends YarnShuffleServiceWithRocksDBBackendSuite {
 
-  /**
-   * Override to add "spark.ssl.rpc.*" configuration parameters...
-   */
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-    // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to 
import here.
-    SslSampleConfigs.createDefaultConfigMap().entrySet().
-      forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslChunkFetchIntegrationSuite extends ChunkFetchIntegrationSuite {
+
+  @BeforeAll
+  public static void setUp() throws Exception {
+    doSetUpWithConfig(new TransportConf(
+      "shuffle", 
SslSampleConfigs.createDefaultConfigProviderForRpcNamespace()));
   }
 }
diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
 
b/common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java
similarity index 51%
copy from 
resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to 
common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java
index 322d6bfdb7c..79b76b633f9 100644
--- 
a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java
@@ -15,20 +15,25 @@
  * limitations under the License.
  */
 
-package org.apache.spark.network.yarn
+package org.apache.spark.network.client;
 
-import org.apache.spark.network.ssl.SslSampleConfigs
+import org.junit.jupiter.api.BeforeEach;
 
-class SslYarnShuffleServiceWithRocksDBBackendSuite
-  extends YarnShuffleServiceWithRocksDBBackendSuite {
+import org.apache.spark.network.ssl.SslSampleConfigs;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.TransportContext;
 
-  /**
-   * Override to add "spark.ssl.rpc.*" configuration parameters...
-   */
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-    // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to 
import here.
-    SslSampleConfigs.createDefaultConfigMap().entrySet().
-      forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslTransportClientFactorySuite extends 
TransportClientFactorySuite {
+
+  @BeforeEach
+  public void setUp() {
+    conf = new TransportConf(
+      "shuffle", 
SslSampleConfigs.createDefaultConfigProviderForRpcNamespace());
+    RpcHandler rpcHandler = new NoOpRpcHandler();
+    context = new TransportContext(conf, rpcHandler);
+    server1 = context.createServer();
+    server2 = context.createServer();
   }
 }
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
index 49a2d570d96..b57f0be920c 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java
@@ -44,10 +44,10 @@ import org.apache.spark.network.util.TransportConf;
 import static org.junit.jupiter.api.Assertions.*;
 
 public class TransportClientFactorySuite {
-  private TransportConf conf;
-  private TransportContext context;
-  private TransportServer server1;
-  private TransportServer server2;
+  protected TransportConf conf;
+  protected TransportContext context;
+  protected TransportServer server1;
+  protected TransportServer server2;
 
   @BeforeEach
   public void setUp() {
diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
index e0971d49510..feaaa570b73 100644
--- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
+++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
@@ -22,6 +22,7 @@ import java.util.List;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.SimpleChannelInboundHandler;
@@ -81,16 +82,16 @@ public class ShuffleTransportContext extends 
TransportContext {
   }
 
   @Override
-  public TransportChannelHandler initializePipeline(SocketChannel channel) {
-    TransportChannelHandler ch = super.initializePipeline(channel);
+  public TransportChannelHandler initializePipeline(SocketChannel channel, 
boolean isClient) {
+    TransportChannelHandler ch = super.initializePipeline(channel, isClient);
     addHandlerToPipeline(channel, ch);
     return ch;
   }
 
   @Override
   public TransportChannelHandler initializePipeline(SocketChannel channel,
-      RpcHandler channelRpcHandler) {
-    TransportChannelHandler ch = super.initializePipeline(channel, 
channelRpcHandler);
+      RpcHandler channelRpcHandler, boolean isClient) {
+    TransportChannelHandler ch = super.initializePipeline(channel, 
channelRpcHandler, isClient);
     addHandlerToPipeline(channel, ch);
     return ch;
   }
@@ -112,6 +113,7 @@ public class ShuffleTransportContext extends 
TransportContext {
     return finalizeWorkers == null ? super.getDecoder() : SHUFFLE_DECODER;
   }
 
+  @ChannelHandler.Sharable
   static class ShuffleMessageDecoder extends MessageToMessageDecoder<ByteBuf> {
 
     private final MessageDecoder delegate;
diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index b5ffa30f62d..73cb133f17e 100644
--- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -32,7 +32,6 @@ import java.util.concurrent.Future;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 
-import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.server.OneForOneStreamManager;
@@ -57,11 +56,11 @@ public class ExternalShuffleIntegrationSuite {
   private static final String APP_ID = "app-id";
   private static final String SORT_MANAGER = 
"org.apache.spark.shuffle.sort.SortShuffleManager";
 
-  private static final int RDD_ID = 1;
-  private static final int SPLIT_INDEX_VALID_BLOCK = 0;
+  protected static final int RDD_ID = 1;
+  protected static final int SPLIT_INDEX_VALID_BLOCK = 0;
   private static final int SPLIT_INDEX_MISSING_FILE = 1;
-  private static final int SPLIT_INDEX_CORRUPT_LENGTH = 2;
-  private static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3;
+  protected static final int SPLIT_INDEX_CORRUPT_LENGTH = 2;
+  protected static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3;
   private static final int SPLIT_INDEX_MISSING_BLOCK_TO_RM = 4;
 
   // Executor 0 is sort-based
@@ -86,8 +85,20 @@ public class ExternalShuffleIntegrationSuite {
     new byte[54321],
   };
 
+  private static TransportConf createTransportConf(int maxRetries, boolean 
rddEnabled) {
+    HashMap<String, String> config = new HashMap<>();
+    config.put("spark.shuffle.io.maxRetries", String.valueOf(maxRetries));
+    config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, 
String.valueOf(rddEnabled));
+    return new TransportConf("shuffle", new MapConfigProvider(config));
+  }
+
+  // This is split out so it can be invoked in a subclass with a different 
config
   @BeforeAll
   public static void beforeAll() throws IOException {
+    doBeforeAllWithConfig(createTransportConf(0, true));
+  }
+
+  public static void doBeforeAllWithConfig(TransportConf transportConf) throws 
IOException {
     Random rand = new Random();
 
     for (byte[] block : exec0Blocks) {
@@ -105,10 +116,7 @@ public class ExternalShuffleIntegrationSuite {
     dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK, 
exec0RddBlockValid);
     dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK_TO_RM, 
exec0RddBlockToRemove);
 
-    HashMap<String, String> config = new HashMap<>();
-    config.put("spark.shuffle.io.maxRetries", "0");
-    config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, "true");
-    conf = new TransportConf("shuffle", new MapConfigProvider(config));
+    conf = transportConf;
     handler = new ExternalBlockHandler(
       new OneForOneStreamManager(),
       new ExternalShuffleBlockResolver(conf, null) {
@@ -319,8 +327,7 @@ public class ExternalShuffleIntegrationSuite {
 
   @Test
   public void testFetchNoServer() throws Exception {
-    TransportConf clientConf = new TransportConf("shuffle",
-      new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", 
"0")));
+    TransportConf clientConf = createTransportConf(0, false);
     registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
     FetchResult execFetch = fetchBlocks("exec-0",
       new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port 
*/);
diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index b8beec303ae..76f82800c50 100644
--- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -39,10 +39,19 @@ import org.apache.spark.network.util.TransportConf;
 
 public class ExternalShuffleSecuritySuite {
 
-  TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
+  TransportConf conf = createTransportConf(false);
   TransportServer server;
   TransportContext transportContext;
 
+  protected TransportConf createTransportConf(boolean encrypt) {
+    if (encrypt) {
+      return new TransportConf("shuffle", new MapConfigProvider(
+        ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
+    } else {
+      return new TransportConf("shuffle", MapConfigProvider.EMPTY);
+    }
+  }
+
   @BeforeEach
   public void beforeEach() throws IOException {
     transportContext = new TransportContext(conf, new 
ExternalBlockHandler(conf, null));
@@ -92,8 +101,7 @@ public class ExternalShuffleSecuritySuite {
         throws IOException, InterruptedException {
     TransportConf testConf = conf;
     if (encrypt) {
-      testConf = new TransportConf("shuffle", new MapConfigProvider(
-        ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true")));
+      testConf = createTransportConf(encrypt);
     }
 
     try (ExternalBlockStoreClient client =
diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
index 5484e8131a8..de164474766 100644
--- 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
+++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
@@ -60,13 +60,16 @@ public class ShuffleTransportContextSuite {
     blockHandler = mock(ExternalBlockHandler.class);
   }
 
-  ShuffleTransportContext createShuffleTransportContext(boolean 
separateFinalizeThread)
-      throws IOException {
+  protected TransportConf createTransportConf(boolean separateFinalizeThread) {
     Map<String, String> configs = new HashMap<>();
     configs.put("spark.shuffle.server.finalizeShuffleMergeThreadsPercent",
-        separateFinalizeThread ? "1" : "0");
-    TransportConf transportConf = new TransportConf("shuffle",
-        new MapConfigProvider(configs));
+      separateFinalizeThread ? "1" : "0");
+    return new TransportConf("shuffle", new MapConfigProvider(configs));
+  }
+
+  ShuffleTransportContext createShuffleTransportContext(boolean 
separateFinalizeThread)
+      throws IOException {
+    TransportConf transportConf = createTransportConf(separateFinalizeThread);
     return new ShuffleTransportContext(transportConf, blockHandler, true);
   }
 
@@ -90,15 +93,17 @@ public class ShuffleTransportContextSuite {
   public void testInitializePipeline() throws IOException {
     // SPARK-43987: test that the FinalizedHandler is added to the pipeline 
only when configured
     for (boolean enabled : new boolean[]{true, false}) {
-      ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
-      SocketChannel channel = new NioSocketChannel();
-      RpcHandler rpcHandler = mock(RpcHandler.class);
-      ctx.initializePipeline(channel, rpcHandler);
-      String handlerName = 
ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
-      if (enabled) {
-        Assertions.assertNotNull(channel.pipeline().get(handlerName));
-      } else {
-        Assertions.assertNull(channel.pipeline().get(handlerName));
+      for (boolean client: new boolean[]{true, false}) {
+        ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
+        SocketChannel channel = new NioSocketChannel();
+        RpcHandler rpcHandler = mock(RpcHandler.class);
+        ctx.initializePipeline(channel, rpcHandler, client);
+        String handlerName = 
ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
+        if (enabled) {
+          Assertions.assertNotNull(channel.pipeline().get(handlerName));
+        } else {
+          Assertions.assertNull(channel.pipeline().get(handlerName));
+        }
       }
     }
   }
diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java
 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java
new file mode 100644
index 00000000000..3591ccad150
--- /dev/null
+++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import org.junit.jupiter.api.BeforeAll;
+
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
+
+public class SslExternalShuffleIntegrationSuite extends ExternalShuffleIntegrationSuite {
+
+  private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) {
+    HashMap<String, String> config = new HashMap<>();
+    config.put("spark.shuffle.io.maxRetries", String.valueOf(maxRetries));
+    config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, String.valueOf(rddEnabled));
+    return new TransportConf(
+      "shuffle",
+      SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(config)
+    );
+  }
+
+  @BeforeAll
+  public static void beforeAll() throws IOException {
+    doBeforeAllWithConfig(createTransportConf(0, true));
+  }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java
similarity index 50%
copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java
index 322d6bfdb7c..061d63dbcd7 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java
@@ -15,20 +15,31 @@
  * limitations under the License.
  */
 
-package org.apache.spark.network.yarn
+package org.apache.spark.network.shuffle;
 
-import org.apache.spark.network.ssl.SslSampleConfigs
+import com.google.common.collect.ImmutableMap;
 
-class SslYarnShuffleServiceWithRocksDBBackendSuite
-  extends YarnShuffleServiceWithRocksDBBackendSuite {
+import org.apache.spark.network.ssl.SslSampleConfigs;
+import org.apache.spark.network.util.TransportConf;
 
-  /**
-   * Override to add "spark.ssl.rpc.*" configuration parameters...
-   */
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-    // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
-    SslSampleConfigs.createDefaultConfigMap().entrySet().
-      forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslExternalShuffleSecuritySuite extends ExternalShuffleSecuritySuite {
+
+  @Override
+  protected TransportConf createTransportConf(boolean encrypt) {
+    if (encrypt) {
+      return new TransportConf(
+        "shuffle",
+        SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(
+          ImmutableMap.of(
+          "spark.authenticate.enableSaslEncryption",
+          "true")
+        )
+      );
+    } else {
+      return new TransportConf(
+        "shuffle",
+        SslSampleConfigs.createDefaultConfigProviderForRpcNamespace()
+      );
+    }
   }
 }
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java
similarity index 55%
copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
copy to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java
index 322d6bfdb7c..51463bbad55 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java
@@ -15,20 +15,24 @@
  * limitations under the License.
  */
 
-package org.apache.spark.network.yarn
+package org.apache.spark.network.shuffle;
 
-import org.apache.spark.network.ssl.SslSampleConfigs
+import com.google.common.collect.ImmutableMap;
 
-class SslYarnShuffleServiceWithRocksDBBackendSuite
-  extends YarnShuffleServiceWithRocksDBBackendSuite {
+import org.apache.spark.network.ssl.SslSampleConfigs;
+import org.apache.spark.network.util.TransportConf;
 
-  /**
-   * Override to add "spark.ssl.rpc.*" configuration parameters...
-   */
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-    // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
-    SslSampleConfigs.createDefaultConfigMap().entrySet().
-      forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
+public class SslShuffleTransportContextSuite extends ShuffleTransportContextSuite {
+
+  @Override
+  protected TransportConf createTransportConf(boolean separateFinalizeThread) {
+    return new TransportConf(
+      "shuffle",
+      SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(
+        ImmutableMap.of(
+          "spark.shuffle.server.finalizeShuffleMergeThreadsPercent",
+          separateFinalizeThread ? "1" : "0")
+      )
+    );
   }
 }
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
index 322d6bfdb7c..06b91faf44a 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala
@@ -28,7 +28,7 @@ class SslYarnShuffleServiceWithRocksDBBackendSuite
   override def beforeEach(): Unit = {
     super.beforeEach()
     // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here.
-    SslSampleConfigs.createDefaultConfigMap().entrySet().
+    SslSampleConfigs.createDefaultConfigMapForRpcNamespace().entrySet().
       forEach(entry => yarnConfig.set(entry.getKey, entry.getValue))
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to