This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 298e4fa [SPARK-27275][CORE] Fix potential corruption in EncryptedMessage.transferTo (2.4) 298e4fa is described below commit 298e4fa6f8054c54e246f91b70d62174ccdb9413 Author: Shixiong Zhu <zsxw...@gmail.com> AuthorDate: Thu Mar 28 11:13:11 2019 -0700 [SPARK-27275][CORE] Fix potential corruption in EncryptedMessage.transferTo (2.4) ## What changes were proposed in this pull request? Backport https://github.com/apache/spark/pull/24211 to 2.4 ## How was this patch tested? Jenkins Closes #24229 from zsxwing/SPARK-27275-2.4. Authored-by: Shixiong Zhu <zsxw...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/network/crypto/TransportCipher.java | 53 ++++++++++---- .../spark/network/crypto/AuthEngineSuite.java | 85 ++++++++++++++++++++++ .../spark/network/crypto/AuthIntegrationSuite.java | 47 ++++++++++-- 3 files changed, 167 insertions(+), 18 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index b64e4b7..0b674cc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -44,7 +44,8 @@ public class TransportCipher { @VisibleForTesting static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; - private static final int STREAM_BUFFER_SIZE = 1024 * 32; + @VisibleForTesting + static final int STREAM_BUFFER_SIZE = 1024 * 32; private final Properties conf; private final String cipher; @@ -84,7 +85,8 @@ public class TransportCipher { return outIv; } - private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + @VisibleForTesting + CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); } @@ -104,7 +106,8 @@ public class TransportCipher { .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); } - private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + @VisibleForTesting + static class EncryptionHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; @@ -116,7 +119,12 @@ public class TransportCipher { @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + ctx.write(createEncryptedMessage(msg), promise); + } + + @VisibleForTesting + EncryptedMessage createEncryptedMessage(Object msg) { + return new EncryptedMessage(cos, msg, byteChannel); } @Override @@ -161,10 +169,12 @@ public class TransportCipher { } } - private static class EncryptedMessage extends AbstractFileRegion { + @VisibleForTesting + static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; + private final long count; private long transferred; private CryptoOutputStream cos; @@ -186,11 +196,12 @@ public class TransportCipher { this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); this.cos = cos; this.byteEncChannel = ch; + this.count = isByteBuf ? buf.readableBytes() : region.count(); } @Override public long count() { - return isByteBuf ? buf.readableBytes() : region.count(); + return count; } @Override @@ -242,22 +253,38 @@ public class TransportCipher { public long transferTo(WritableByteChannel target, long position) throws IOException { Preconditions.checkArgument(position == transferred(), "Invalid position."); + if (transferred == count) { + return 0; + } + + long totalBytesWritten = 0L; do { if (currentEncrypted == null) { encryptMore(); } - int bytesWritten = currentEncrypted.remaining(); - target.write(currentEncrypted); - bytesWritten -= currentEncrypted.remaining(); - transferred += bytesWritten; - if (!currentEncrypted.hasRemaining()) { + long remaining = currentEncrypted.remaining(); + if (remaining == 0) { + // Just for safety to avoid endless loop. It usually won't happen, but since the + // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for + // safety. currentEncrypted = null; byteEncChannel.reset(); + return totalBytesWritten; } - } while (transferred < count()); - return transferred; + long bytesWritten = target.write(currentEncrypted); + totalBytesWritten += bytesWritten; + transferred += bytesWritten; + if (bytesWritten < remaining) { + // break as the underlying buffer in "target" is full + break; + } + currentEncrypted = null; + byteEncChannel.reset(); + } while (transferred < count); + + return totalBytesWritten; } private void encryptMore() throws IOException { diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index a3519fe..46b6305 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -17,13 +17,24 @@ package org.apache.spark.network.crypto; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import java.util.Random; + import static java.nio.charset.StandardCharsets.UTF_8; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; import org.junit.BeforeClass; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; +import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -104,4 +115,78 @@ public class AuthEngineSuite { challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); } + @Test + public void testEncryptedMessage() throws Exception { + AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf); + try { + ClientChallenge clientChallenge = client.challenge(); + ServerResponse serverResponse = server.respond(clientChallenge); + client.validate(serverResponse); + + TransportCipher cipher = server.sessionCipher(); + TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); + + byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1]; + new Random().nextBytes(data); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); + while (emsg.transfered() < emsg.count()) { + emsg.transferTo(channel, emsg.transfered()); + } + assertEquals(data.length, channel.length()); + } finally { + client.close(); + server.close(); + } + } + + @Test + public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception { + AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf); + try { + ClientChallenge clientChallenge = client.challenge(); + ServerResponse serverResponse = server.respond(clientChallenge); + client.validate(serverResponse); + + TransportCipher cipher = server.sessionCipher(); + TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); + + int testDataLength = 4; + FileRegion region = mock(FileRegion.class); + when(region.count()).thenReturn((long) testDataLength); + // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. + when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() { + + private boolean firstTime = true; + + @Override + public Long answer(InvocationOnMock invocationOnMock) throws Throwable { + if (firstTime) { + firstTime = false; + return 0L; + } else { + WritableByteChannel channel = + invocationOnMock.getArgumentAt(0, WritableByteChannel.class); + channel.write(ByteBuffer.wrap(new byte[testDataLength])); + return (long) testDataLength; + } + } + }); + + TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); + // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. + assertEquals(0L, emsg.transferTo(channel, emsg.transfered())); + assertEquals(testDataLength, emsg.transferTo(channel, emsg.transfered())); + assertEquals(emsg.transfered(), emsg.count()); + assertEquals(4, channel.length()); + } finally { + client.close(); + server.close(); + } + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 8751944..73418a9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -124,6 +124,42 @@ public class AuthIntegrationSuite { } } + @Test + public void testLargeMessageEncryption() throws Exception { + // Use a big length to create a message that cannot be put into the encryption buffer completely + final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE; + ctx = new AuthTestCtx(new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + char[] longMessage = new char[testErrorMessageLength]; + Arrays.fill(longMessage, 'D'); + callback.onFailure(new RuntimeException(new String(longMessage))); + } + + @Override + public StreamManager getStreamManager() { + return null; + } + }); + ctx.createServer("secret"); + ctx.createClient("secret"); + + try { + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + fail("Should have failed unencrypted RPC."); + } catch (Exception e) { + assertTrue(ctx.authRpcHandler.doDelegate); + assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD")); + // Verify we receive the complete error message + int messageStart = e.getMessage().indexOf("DDDDD"); + int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5; + assertEquals(testErrorMessageLength, messageEnd - messageStart); + } + } + private class AuthTestCtx { private final String appId = "testAppId"; @@ -136,10 +172,7 @@ public class AuthIntegrationSuite { volatile AuthRpcHandler authRpcHandler; AuthTestCtx() throws Exception { - Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); - this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); - - RpcHandler rpcHandler = new RpcHandler() { + this(new RpcHandler() { @Override public void receive( TransportClient client, @@ -153,8 +186,12 @@ public class AuthIntegrationSuite { public StreamManager getStreamManager() { return null; } - }; + }); + } + AuthTestCtx(RpcHandler rpcHandler) throws Exception { + Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); + this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); this.ctx = new TransportContext(conf, rpcHandler); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org