mridulm commented on code in PR #43541:
URL: https://github.com/apache/spark/pull/43541#discussion_r1374076018


##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
+
       if (nettyLogger.getLoggingHandler() != null) {
         pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
       }
+
+      if (sslEncryptionEnabled()) {
+        SslHandler sslHandler;
+        try {
+          sslHandler = new SslHandler(
+            sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));

Review Comment:
   nit:
   ```suggestion
               sslFactory.createSSLEngine(isClient, channel.alloc()));
   ```



##########
common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java:
##########
@@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
     } else if (cf.cause() != null) {
      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
+    if (context.sslEncryptionEnabled()) {
+      final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
+      Future<Channel> future = sslHandler.handshakeFuture().addListener(
+        new GenericFutureListener<Future<Channel>>() {
+          @Override
+          public void operationComplete(final Future<Channel> handshakeFuture) {
+            if (handshakeFuture.isSuccess()) {
+              logger.debug("Successfully completed TLS handshake to {}", address);
+            } else {
+              if (logger.isDebugEnabled()) {
+                logger.debug(
+                  "failed to complete TLS handshake to " + address,
+                  handshakeFuture.cause());
+              }
+              cf.channel().close();
+            }
+          }
+      });
+      future.await(conf.connectionTimeoutMs());

Review Comment:
   Throw an exception when `await` fails (after closing the connection)?
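   
   A minimal sketch of that, reusing the variables from the diff above (`future`, `cf`, `address`, `conf`); illustrative only, not the PR's final code:
   ```java
   // Netty's await(timeout) returns false if the handshake future did not
   // complete within the timeout; close the connection and surface a clear error.
   if (!future.await(conf.connectionTimeoutMs())) {
     cf.channel().close();
     throw new IOException(String.format(
       "Failed to complete TLS handshake to %s within %d ms",
       address, conf.connectionTimeoutMs()));
   }
   ```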



##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
+
       if (nettyLogger.getLoggingHandler() != null) {
         pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
       }
+
+      if (sslEncryptionEnabled()) {
+        SslHandler sslHandler;
+        try {
+          sslHandler = new SslHandler(
+            sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));
+        } catch (Exception e) {
+          throw new RuntimeException("Error creating Netty SslHandler", e);
+        }
+        pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+        // Cannot use zero-copy with HTTPS, so we add in our ChunkedWriteHandler
+        // just before the MessageEncoder
+        pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());

Review Comment:
   `addFirst` and `addLast` for `sslHandler` should be equivalent at this point.
   But if we want to do `addFirst`, then perhaps ensure `ChunkedWriteHandler` is added with `addAfter` relative to `sslHandler`?
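   
   For illustration, a sketch of that (handler names taken from the diff above; not the PR's final code):
   ```java
   pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
   // Anchor the ChunkedWriteHandler's position to the SslHandler explicitly,
   // instead of relying on addLast ordering.
   pipeline.addAfter("NettySslEncryptionHandler", "chunkedWriter", new ChunkedWriteHandler());
   ```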
   
   



##########
common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java:
##########
@@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
     } else if (cf.cause() != null) {
      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
+    if (context.sslEncryptionEnabled()) {
+      final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
+      Future<Channel> future = sslHandler.handshakeFuture().addListener(
+        new GenericFutureListener<Future<Channel>>() {
+          @Override
+          public void operationComplete(final Future<Channel> handshakeFuture) {
+            if (handshakeFuture.isSuccess()) {
+              logger.debug("Successfully completed TLS handshake to {}", address);
+            } else {
+              if (logger.isDebugEnabled()) {
+                logger.debug(

Review Comment:
   Do we want to make this `info` instead? I am assuming it won't be noisy, and when it does fail, it is something we want to know about?
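   
   A sketch of that change (mirroring the diff above; illustrative only):
   ```java
   // Log unconditionally at info so handshake failures are visible by default.
   logger.info("Failed to complete TLS handshake to " + address, handshakeFuture.cause());
   ```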



##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java:
##########
@@ -86,6 +85,17 @@ public class ExternalShuffleIntegrationSuite {
     new byte[54321],
   };
 
+  private static TransportConf createTransportConf(String maxRetries, String rddEnabled) {

Review Comment:
   nit: specify the actual types in the signature and convert them to `String` inside this method.
   
   ```suggestion
     private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) {
   ```
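   
   For illustration, the body would then do the conversion itself (a sketch based on the method shown later in this review; not the PR's final code):
   ```java
   private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) {
     HashMap<String, String> config = new HashMap<>();
     // Convert the typed arguments to the String values the config map expects.
     config.put("spark.shuffle.io.maxRetries", Integer.toString(maxRetries));
     config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, Boolean.toString(rddEnabled));
     return new TransportConf("shuffle", new MapConfigProvider(config));
   }
   ```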



##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
+

Review Comment:
   super nit:
   ```suggestion
   ```



##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -223,6 +255,33 @@ protected MessageToMessageDecoder<ByteBuf> getDecoder() {
     return DECODER;
   }
 
+  private SSLFactory createSslFactory() {
+    if (conf.sslRpcEnabled()) {
+      if (conf.sslRpcEnabledAndKeysAreValid()) {
+        return new SSLFactory.Builder()
+          .openSslEnabled(conf.sslRpcOpenSslEnabled())
+          .requestedProtocol(conf.sslRpcProtocol())
+          .requestedCiphers(conf.sslRpcRequestedCiphers())
+          .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword())
+          .privateKey(conf.sslRpcPrivateKey())
+          .keyPassword(conf.sslRpcKeyPassword())
+          .certChain(conf.sslRpcCertChain())
+          .trustStore(
+            conf.sslRpcTrustStore(),
+            conf.sslRpcTrustStorePassword(),
+            conf.sslRpcTrustStoreReloadingEnabled(),
+            conf.sslRpctrustStoreReloadIntervalMs())
+          .build();
+      } else {
+        logger.error("RPC SSL encryption enabled but keys not found!" +
+          "Please ensure the configured keys are present.");
+        throw new RuntimeException("RPC SSL encryption enabled but keys not 
found!");

Review Comment:
   ```suggestion
           throw new IllegalArgumentException("RPC SSL encryption enabled but keys not found!");
   ```



##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java:
##########
@@ -90,15 +94,22 @@ private ByteBuf getDecodableMessageBuf(Message req) throws Exception {
   public void testInitializePipeline() throws IOException {
    // SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured
     for (boolean enabled : new boolean[]{true, false}) {
-      ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
-      SocketChannel channel = new NioSocketChannel();
-      RpcHandler rpcHandler = mock(RpcHandler.class);
-      ctx.initializePipeline(channel, rpcHandler);
-      String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
-      if (enabled) {
-        Assertions.assertNotNull(channel.pipeline().get(handlerName));
-      } else {
-        Assertions.assertNull(channel.pipeline().get(handlerName));
+      for (boolean isClient: new boolean[]{true, false}) {

Review Comment:
   super nit: `isClient` -> `client`



##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java:
##########
@@ -52,7 +53,8 @@
  * */
 public class ShuffleTransportContext extends TransportContext {
  private static final Logger logger = LoggerFactory.getLogger(ShuffleTransportContext.class);
-  private static final ShuffleMessageDecoder SHUFFLE_DECODER =
+  @VisibleForTesting
+  protected static ShuffleMessageDecoder SHUFFLE_DECODER =
       new ShuffleMessageDecoder(MessageDecoder.INSTANCE);

Review Comment:
   Instead of exposing the variable, add a method to reinitialize it, and annotate it as for use by tests.
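   
   Something along these lines (a sketch; the method name is illustrative, not from the PR):
   ```java
   private static ShuffleMessageDecoder SHUFFLE_DECODER =
       new ShuffleMessageDecoder(MessageDecoder.INSTANCE);

   // Test-only hook: rebuild the decoder without widening the field's visibility.
   @VisibleForTesting
   static void resetShuffleDecoder() {
     SHUFFLE_DECODER = new ShuffleMessageDecoder(MessageDecoder.INSTANCE);
   }
   ```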



##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java:
##########
@@ -86,6 +85,17 @@ public class ExternalShuffleIntegrationSuite {
     new byte[54321],
   };
 
+  private static TransportConf createTransportConf(String maxRetries, String rddEnabled) {
+    HashMap<String, String> config = new HashMap<>();
+    config.put("spark.shuffle.io.maxRetries", maxRetries);
+    config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, rddEnabled);
+    return new TransportConf("shuffle", new MapConfigProvider(config));
+  }
+
+  protected TransportConf createTransportConfForFetchNoServerTest() {

Review Comment:
   It is unclear to me why this method is named this way ...



##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java:
##########
@@ -90,15 +94,22 @@ private ByteBuf getDecodableMessageBuf(Message req) throws Exception {
   public void testInitializePipeline() throws IOException {
    // SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured
     for (boolean enabled : new boolean[]{true, false}) {
-      ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
-      SocketChannel channel = new NioSocketChannel();
-      RpcHandler rpcHandler = mock(RpcHandler.class);
-      ctx.initializePipeline(channel, rpcHandler);
-      String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
-      if (enabled) {
-        Assertions.assertNotNull(channel.pipeline().get(handlerName));
-      } else {
-        Assertions.assertNull(channel.pipeline().get(handlerName));
+      for (boolean isClient: new boolean[]{true, false}) {
+        // Since the decoder is not @Sharable, reset it between test runs to avoid errors:
+        // it is used across both ShuffleTransportContextSuite and
+        // SslShuffleTransportContextSuite, and by both servers and clients

Review Comment:
   The decoder is not being used here (other than for configuring the pipeline), so why do we need to reset it?



##########
common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java:
##########
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+import com.google.common.io.Closeables;
+import org.junit.jupiter.api.BeforeAll;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
+
+
+public class SslChunkFetchIntegrationSuite extends ChunkFetchIntegrationSuite {
+
+  @BeforeAll
+  public static void setUp() throws Exception {
+    int bufSize = 100000;
+    final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+    for (int i = 0; i < bufSize; i ++) {
+      buf.put((byte) i);
+    }
+    buf.flip();
+    bufferChunk = new NioManagedBuffer(buf);
+
+    testFile = File.createTempFile("shuffle-test-file", "txt");
+    testFile.deleteOnExit();
+    RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+    boolean shouldSuppressIOException = true;
+    try {
+      byte[] fileContent = new byte[1024];
+      new Random().nextBytes(fileContent);
+      fp.write(fileContent);
+      shouldSuppressIOException = false;
+    } finally {
+      Closeables.close(fp, shouldSuppressIOException);
+    }
+
+    final TransportConf conf = new TransportConf(
+      "shuffle", 
SslSampleConfigs.createDefaultConfigProviderForRpcNamespace());
+    fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+
+    streamManager = new StreamManager() {
+      @Override
+      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+        assertEquals(STREAM_ID, streamId);
+        if (chunkIndex == BUFFER_CHUNK_INDEX) {
+          return new NioManagedBuffer(buf);
+        } else if (chunkIndex == FILE_CHUNK_INDEX) {
+          return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+        } else {
+          throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+        }
+      }
+    };
+    RpcHandler handler = new RpcHandler() {
+      @Override
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return streamManager;
+      }
+    };
+    context = new TransportContext(conf, handler);
+    server = context.createServer();
+    clientFactory = context.createClientFactory();
+  }

Review Comment:
   If I am not wrong, the only change between this and `ChunkFetchIntegrationSuite.setUp` is `conf`, right?
   If yes, instead of duplicating the method, pass the `conf` to a common static method that initializes both Suites? Something like the sketch below.
   
   (Same comment for the other Suites too)
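   
   A sketch of that refactor, assuming a shared helper on the base suite (names are illustrative, not from the PR):
   ```java
   // In ChunkFetchIntegrationSuite: shared, conf-parameterized initialization.
   protected static void initializeSuite(TransportConf transportConf) throws Exception {
     // ... the existing setUp body, using transportConf instead of
     // constructing a TransportConf inline ...
   }

   @BeforeAll
   public static void setUp() throws Exception {
     initializeSuite(new TransportConf("shuffle", MapConfigProvider.EMPTY));
   }

   // In SslChunkFetchIntegrationSuite, only the conf differs:
   @BeforeAll
   public static void setUp() throws Exception {
     initializeSuite(new TransportConf(
       "shuffle", SslSampleConfigs.createDefaultConfigProviderForRpcNamespace()));
   }
   ```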


