mridulm commented on code in PR #43541: URL: https://github.com/apache/spark/pull/43541#discussion_r1374076018
##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
+
       if (nettyLogger.getLoggingHandler() != null) {
         pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
       }
+
+      if (sslEncryptionEnabled()) {
+        SslHandler sslHandler;
+        try {
+          sslHandler = new SslHandler(
+            sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));

Review Comment:
   nit:
   ```suggestion
               sslFactory.createSSLEngine(isClient, channel.alloc()));
   ```

##########
common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java:
##########
@@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
     } else if (cf.cause() != null) {
       throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
+    if (context.sslEncryptionEnabled()) {
+      final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
+      Future<Channel> future = sslHandler.handshakeFuture().addListener(
+        new GenericFutureListener<Future<Channel>>() {
+          @Override
+          public void operationComplete(final Future<Channel> handshakeFuture) {
+            if (handshakeFuture.isSuccess()) {
+              logger.debug("{} successfully completed TLS handshake to ", address);
+            } else {
+              if (logger.isDebugEnabled()) {
+                logger.debug(
+                  "failed to complete TLS handshake to " + address,
+                  handshakeFuture.cause());
+              }
+              cf.channel().close();
+            }
+          }
+      });
+      future.await(conf.connectionTimeoutMs());

Review Comment:
   Throw exception when await fails ? (after closing connection)

##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
+
       if (nettyLogger.getLoggingHandler() != null) {
         pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
       }
+
+      if (sslEncryptionEnabled()) {
+        SslHandler sslHandler;
+        try {
+          sslHandler = new SslHandler(
+            sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));
+        } catch (Exception e) {
+          throw new RuntimeException("Error creating Netty SslHandler", e);
+        }
+        pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+        // Cannot use zero-copy with HTTPS, so we add in our ChunkedWriteHandler just before the
+        // MessageEncoder
+        pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());

Review Comment:
   `addFirst` and `addLast` for `sslHandler` should be the same at this point.
   But, if we want to do `addFirst`, then perhaps ensure `ChunkedWriteHandler` is added with `addAfter` `sslHandler` ?
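
To make the ordering point above concrete, here is a rough sketch using a toy list-backed pipeline. It is not Netty's `ChannelPipeline`; `PipelineOrderSketch`, `asInDiff`, and `withAddAfter` are hypothetical names introduced only for illustration, and the handler names are the ones from the diff. It assumes the optional logging handler is present, which is the case where the two approaches could end up with different orderings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;

/** Toy model of a named handler pipeline. This is NOT Netty's ChannelPipeline. */
class PipelineOrderSketch {
  private final List<String> names = new ArrayList<>();

  PipelineOrderSketch addFirst(String name) {
    names.add(0, name);
    return this;
  }

  PipelineOrderSketch addLast(String name) {
    names.add(name);
    return this;
  }

  /** Inserts name directly after base, wherever base currently sits. */
  PipelineOrderSketch addAfter(String base, String name) {
    int i = names.indexOf(base);
    if (i < 0) {
      throw new NoSuchElementException(base);
    }
    names.add(i + 1, name);
    return this;
  }

  List<String> names() {
    return names;
  }

  /** Ordering produced by the calls in the diff, with the logging handler present. */
  static List<String> asInDiff() {
    return new PipelineOrderSketch()
        .addLast("loggingHandler")
        .addFirst("NettySslEncryptionHandler")
        .addLast("chunkedWriter")
        .names();
  }

  /** Same handlers, but chunkedWriter is pinned right after the SSL handler. */
  static List<String> withAddAfter() {
    return new PipelineOrderSketch()
        .addLast("loggingHandler")
        .addFirst("NettySslEncryptionHandler")
        .addAfter("NettySslEncryptionHandler", "chunkedWriter")
        .names();
  }

  public static void main(String[] args) {
    System.out.println(asInDiff());
    System.out.println(withAddAfter());
  }
}
```

In the toy model, `addAfter` keeps `chunkedWriter` adjacent to the SSL handler regardless of what was added earlier. In Netty itself the corresponding call also takes the handler instance, i.e. `addAfter(baseName, name, handler)`.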
##########
common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java:
##########
@@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
     } else if (cf.cause() != null) {
       throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
+    if (context.sslEncryptionEnabled()) {
+      final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
+      Future<Channel> future = sslHandler.handshakeFuture().addListener(
+        new GenericFutureListener<Future<Channel>>() {
+          @Override
+          public void operationComplete(final Future<Channel> handshakeFuture) {
+            if (handshakeFuture.isSuccess()) {
+              logger.debug("{} successfully completed TLS handshake to ", address);
+            } else {
+              if (logger.isDebugEnabled()) {
+                logger.debug(

Review Comment:
   Do we want to make this `info` instead ?
   I am assuming it won't be noisy, and when it does fail, it is something we want to know about ?

##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java:
##########
@@ -86,6 +85,17 @@ public class ExternalShuffleIntegrationSuite {
     new byte[54321],
   };
+  private static TransportConf createTransportConf(String maxRetries, String rddEnabled) {

Review Comment:
   nit: specify using the actual types and convert it to `String` in this method.

   ```suggestion
     private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) {
   ```

##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
    */
   public TransportChannelHandler initializePipeline(
       SocketChannel channel,
-      RpcHandler channelRpcHandler) {
+      RpcHandler channelRpcHandler,
+      boolean isClient) {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       ChannelPipeline pipeline = channel.pipeline();
+

Review Comment:
   super nit:
   ```suggestion
   ```

##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -223,6 +255,33 @@ protected MessageToMessageDecoder<ByteBuf> getDecoder() {
     return DECODER;
   }
+  private SSLFactory createSslFactory() {
+    if (conf.sslRpcEnabled()) {
+      if (conf.sslRpcEnabledAndKeysAreValid()) {
+        return new SSLFactory.Builder()
+          .openSslEnabled(conf.sslRpcOpenSslEnabled())
+          .requestedProtocol(conf.sslRpcProtocol())
+          .requestedCiphers(conf.sslRpcRequestedCiphers())
+          .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword())
+          .privateKey(conf.sslRpcPrivateKey())
+          .keyPassword(conf.sslRpcKeyPassword())
+          .certChain(conf.sslRpcCertChain())
+          .trustStore(
+            conf.sslRpcTrustStore(),
+            conf.sslRpcTrustStorePassword(),
+            conf.sslRpcTrustStoreReloadingEnabled(),
+            conf.sslRpctrustStoreReloadIntervalMs())
+          .build();
+      } else {
+        logger.error("RPC SSL encryption enabled but keys not found!"
+          + "Please ensure the configured keys are present.");
+        throw new RuntimeException("RPC SSL encryption enabled but keys not found!");

Review Comment:
   ```suggestion
           throw new IllegalArgumentException("RPC SSL encryption enabled but keys not found!");
   ```

##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java:
##########
@@ -90,15 +94,22 @@ private ByteBuf getDecodableMessageBuf(Message req) throws Exception {
   public void testInitializePipeline() throws IOException {
     // SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured
     for (boolean enabled : new boolean[]{true, false}) {
-      ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
-      SocketChannel channel = new NioSocketChannel();
-      RpcHandler rpcHandler = mock(RpcHandler.class);
-      ctx.initializePipeline(channel, rpcHandler);
-      String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
-      if (enabled) {
-        Assertions.assertNotNull(channel.pipeline().get(handlerName));
-      } else {
-        Assertions.assertNull(channel.pipeline().get(handlerName));
+      for (boolean isClient: new boolean[]{true, false}) {

Review Comment:
   super nit: `isClient` -> `client`

##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java:
##########
@@ -52,7 +53,8 @@
  * */
 public class ShuffleTransportContext extends TransportContext {
   private static final Logger logger = LoggerFactory.getLogger(ShuffleTransportContext.class);
-  private static final ShuffleMessageDecoder SHUFFLE_DECODER =
+  @VisibleForTesting
+  protected static ShuffleMessageDecoder SHUFFLE_DECODER =
       new ShuffleMessageDecoder(MessageDecoder.INSTANCE);

Review Comment:
   Instead of exposing the variable, add a method to reinitialize it - and annotate as for use by tests.
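
A minimal sketch of what the last suggestion could look like, using stand-in types rather than the real Spark classes. `ShuffleContextSketch`, `Decoder`, and `resetDecoderForTesting` are hypothetical names for illustration only; the actual change would presumably use `ShuffleMessageDecoder` and Guava's `@VisibleForTesting`, which are omitted here so the snippet stays self-contained.

```java
/** Illustrative stand-in for ShuffleTransportContext; not the real class. */
class ShuffleContextSketch {
  /** Stand-in for ShuffleMessageDecoder (hypothetical, for illustration). */
  static final class Decoder {}

  // The field stays private, so callers cannot reassign it directly.
  private static Decoder decoder = new Decoder();

  static Decoder getDecoder() {
    return decoder;
  }

  // Test-only hook; the real code would carry an annotation such as
  // Guava's @VisibleForTesting to document the intent.
  static void resetDecoderForTesting() {
    decoder = new Decoder();
  }

  public static void main(String[] args) {
    Decoder before = getDecoder();
    resetDecoderForTesting();
    // A fresh instance replaces the old one after the reset.
    System.out.println(before != getDecoder());
  }
}
```

The field keeps its narrow visibility, and the only mutation path is a method whose name states that it exists for tests.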
##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java:
##########
@@ -86,6 +85,17 @@ public class ExternalShuffleIntegrationSuite {
     new byte[54321],
   };
+  private static TransportConf createTransportConf(String maxRetries, String rddEnabled) {
+    HashMap<String, String> config = new HashMap<>();
+    config.put("spark.shuffle.io.maxRetries", maxRetries);
+    config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, rddEnabled);
+    return new TransportConf("shuffle", new MapConfigProvider(config));
+  }
+
+  protected TransportConf createTransportConfForFetchNoServerTest() {

Review Comment:
   It is unclear to me why this method is named this way ...

##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java:
##########
@@ -90,15 +94,22 @@ private ByteBuf getDecodableMessageBuf(Message req) throws Exception {
   public void testInitializePipeline() throws IOException {
     // SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured
     for (boolean enabled : new boolean[]{true, false}) {
-      ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
-      SocketChannel channel = new NioSocketChannel();
-      RpcHandler rpcHandler = mock(RpcHandler.class);
-      ctx.initializePipeline(channel, rpcHandler);
-      String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
-      if (enabled) {
-        Assertions.assertNotNull(channel.pipeline().get(handlerName));
-      } else {
-        Assertions.assertNull(channel.pipeline().get(handlerName));
+      for (boolean isClient: new boolean[]{true, false}) {
+        // Since the decoder is not Shareable, reset it between test runs to avoid errors since it's
+        // used both across ShuffleTransportContextSuite and SslShuffleTransportContextSuite
+        // and server/clients

Review Comment:
   The decoder is not being used here (other than configuring the pipeline) - why do we need to reset it ?
##########
common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java:
##########
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+import com.google.common.io.Closeables;
+import org.junit.jupiter.api.BeforeAll;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
+
+
+public class SslChunkFetchIntegrationSuite extends ChunkFetchIntegrationSuite {
+
+  @BeforeAll
+  public static void setUp() throws Exception {
+    int bufSize = 100000;
+    final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+    for (int i = 0; i < bufSize; i ++) {
+      buf.put((byte) i);
+    }
+    buf.flip();
+    bufferChunk = new NioManagedBuffer(buf);
+
+    testFile = File.createTempFile("shuffle-test-file", "txt");
+    testFile.deleteOnExit();
+    RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+    boolean shouldSuppressIOException = true;
+    try {
+      byte[] fileContent = new byte[1024];
+      new Random().nextBytes(fileContent);
+      fp.write(fileContent);
+      shouldSuppressIOException = false;
+    } finally {
+      Closeables.close(fp, shouldSuppressIOException);
+    }
+
+    final TransportConf conf = new TransportConf(
+      "shuffle", SslSampleConfigs.createDefaultConfigProviderForRpcNamespace());
+    fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+
+    streamManager = new StreamManager() {
+      @Override
+      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+        assertEquals(STREAM_ID, streamId);
+        if (chunkIndex == BUFFER_CHUNK_INDEX) {
+          return new NioManagedBuffer(buf);
+        } else if (chunkIndex == FILE_CHUNK_INDEX) {
+          return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+        } else {
+          throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+        }
+      }
+    };
+    RpcHandler handler = new RpcHandler() {
+      @Override
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return streamManager;
+      }
+    };
+    context = new TransportContext(conf, handler);
+    server = context.createServer();
+    clientFactory = context.createClientFactory();
+  }

Review Comment:
   If I am not wrong, the only change between this and `ChunkFetchIntegrationSuite.setUp` is `conf` right ?
   If yes, instead of duplicating the method - pass the `conf` to a common static method to initialize for both Suites instead ?

   (Same comment for the other Suites too)

--
This is an automated message from the Apache Git Service.