This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 298e4fa  [SPARK-27275][CORE] Fix potential corruption in EncryptedMessage.transferTo (2.4)
298e4fa is described below

commit 298e4fa6f8054c54e246f91b70d62174ccdb9413
Author: Shixiong Zhu <zsxw...@gmail.com>
AuthorDate: Thu Mar 28 11:13:11 2019 -0700

    [SPARK-27275][CORE] Fix potential corruption in EncryptedMessage.transferTo (2.4)
    
    ## What changes were proposed in this pull request?
    
    Backport https://github.com/apache/spark/pull/24211 to 2.4
    
    ## How was this patch tested?
    
    Jenkins
    
    Closes #24229 from zsxwing/SPARK-27275-2.4.
    
    Authored-by: Shixiong Zhu <zsxw...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/network/crypto/TransportCipher.java      | 53 ++++++++++----
 .../spark/network/crypto/AuthEngineSuite.java      | 85 ++++++++++++++++++++++
 .../spark/network/crypto/AuthIntegrationSuite.java | 47 ++++++++++--
 3 files changed, 167 insertions(+), 18 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index b64e4b7..0b674cc 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -44,7 +44,8 @@ public class TransportCipher {
   @VisibleForTesting
   static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
   private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
-  private static final int STREAM_BUFFER_SIZE = 1024 * 32;
+  @VisibleForTesting
+  static final int STREAM_BUFFER_SIZE = 1024 * 32;
 
   private final Properties conf;
   private final String cipher;
@@ -84,7 +85,8 @@ public class TransportCipher {
     return outIv;
   }
 
-  private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws 
IOException {
+  @VisibleForTesting
+  CryptoOutputStream createOutputStream(WritableByteChannel ch) throws 
IOException {
     return new CryptoOutputStream(cipher, conf, ch, key, new 
IvParameterSpec(outIv));
   }
 
@@ -104,7 +106,8 @@ public class TransportCipher {
       .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
   }
 
-  private static class EncryptionHandler extends ChannelOutboundHandlerAdapter 
{
+  @VisibleForTesting
+  static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
     private final ByteArrayWritableChannel byteChannel;
     private final CryptoOutputStream cos;
 
@@ -116,7 +119,12 @@ public class TransportCipher {
     @Override
     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise)
       throws Exception {
-      ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise);
+      ctx.write(createEncryptedMessage(msg), promise);
+    }
+
+    @VisibleForTesting
+    EncryptedMessage createEncryptedMessage(Object msg) {
+      return new EncryptedMessage(cos, msg, byteChannel);
     }
 
     @Override
@@ -161,10 +169,12 @@ public class TransportCipher {
     }
   }
 
-  private static class EncryptedMessage extends AbstractFileRegion {
+  @VisibleForTesting
+  static class EncryptedMessage extends AbstractFileRegion {
     private final boolean isByteBuf;
     private final ByteBuf buf;
     private final FileRegion region;
+    private final long count;
     private long transferred;
     private CryptoOutputStream cos;
 
@@ -186,11 +196,12 @@ public class TransportCipher {
       this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
       this.cos = cos;
       this.byteEncChannel = ch;
+      this.count = isByteBuf ? buf.readableBytes() : region.count();
     }
 
     @Override
     public long count() {
-      return isByteBuf ? buf.readableBytes() : region.count();
+      return count;
     }
 
     @Override
@@ -242,22 +253,38 @@ public class TransportCipher {
     public long transferTo(WritableByteChannel target, long position) throws 
IOException {
       Preconditions.checkArgument(position == transferred(), "Invalid 
position.");
 
+      if (transferred == count) {
+        return 0;
+      }
+
+      long totalBytesWritten = 0L;
       do {
         if (currentEncrypted == null) {
           encryptMore();
         }
 
-        int bytesWritten = currentEncrypted.remaining();
-        target.write(currentEncrypted);
-        bytesWritten -= currentEncrypted.remaining();
-        transferred += bytesWritten;
-        if (!currentEncrypted.hasRemaining()) {
+        long remaining = currentEncrypted.remaining();
+        if (remaining == 0)  {
+          // Just for safety to avoid endless loop. It usually won't happen, 
but since the
+          // underlying `region.transferTo` is allowed to transfer 0 bytes, we 
should handle it for
+          // safety.
           currentEncrypted = null;
           byteEncChannel.reset();
+          return totalBytesWritten;
         }
-      } while (transferred < count());
 
-      return transferred;
+        long bytesWritten = target.write(currentEncrypted);
+        totalBytesWritten += bytesWritten;
+        transferred += bytesWritten;
+        if (bytesWritten < remaining) {
+          // break as the underlying buffer in "target" is full
+          break;
+        }
+        currentEncrypted = null;
+        byteEncChannel.reset();
+      } while (transferred < count);
+
+      return totalBytesWritten;
     }
 
     private void encryptMore() throws IOException {
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
index a3519fe..46b6305 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -17,13 +17,24 @@
 
 package org.apache.spark.network.crypto;
 
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
+import java.util.Random;
+
 import static java.nio.charset.StandardCharsets.UTF_8;
 
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.FileRegion;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
 
+import org.apache.spark.network.util.ByteArrayWritableChannel;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
@@ -104,4 +115,78 @@ public class AuthEngineSuite {
       challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
   }
 
+  @Test
+  public void testEncryptedMessage() throws Exception {
+    AuthEngine client = new AuthEngine("appId", "secret", conf);
+    AuthEngine server = new AuthEngine("appId", "secret", conf);
+    try {
+      ClientChallenge clientChallenge = client.challenge();
+      ServerResponse serverResponse = server.respond(clientChallenge);
+      client.validate(serverResponse);
+
+      TransportCipher cipher = server.sessionCipher();
+      TransportCipher.EncryptionHandler handler = new 
TransportCipher.EncryptionHandler(cipher);
+
+      byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1];
+      new Random().nextBytes(data);
+      ByteBuf buf = Unpooled.wrappedBuffer(data);
+
+      ByteArrayWritableChannel channel = new 
ByteArrayWritableChannel(data.length);
+      TransportCipher.EncryptedMessage emsg = 
handler.createEncryptedMessage(buf);
+      while (emsg.transfered() < emsg.count()) {
+        emsg.transferTo(channel, emsg.transfered());
+      }
+      assertEquals(data.length, channel.length());
+    } finally {
+      client.close();
+      server.close();
+    }
+  }
+
+  @Test
+  public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception 
{
+    AuthEngine client = new AuthEngine("appId", "secret", conf);
+    AuthEngine server = new AuthEngine("appId", "secret", conf);
+    try {
+      ClientChallenge clientChallenge = client.challenge();
+      ServerResponse serverResponse = server.respond(clientChallenge);
+      client.validate(serverResponse);
+
+      TransportCipher cipher = server.sessionCipher();
+      TransportCipher.EncryptionHandler handler = new 
TransportCipher.EncryptionHandler(cipher);
+
+      int testDataLength = 4;
+      FileRegion region = mock(FileRegion.class);
+      when(region.count()).thenReturn((long) testDataLength);
+      // Make `region.transferTo` do nothing in first call and transfer 4 
bytes in the second one.
+      when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() {
+
+        private boolean firstTime = true;
+
+        @Override
+        public Long answer(InvocationOnMock invocationOnMock) throws Throwable 
{
+          if (firstTime) {
+            firstTime = false;
+            return 0L;
+          } else {
+            WritableByteChannel channel =
+              invocationOnMock.getArgumentAt(0, WritableByteChannel.class);
+            channel.write(ByteBuffer.wrap(new byte[testDataLength]));
+            return (long) testDataLength;
+          }
+        }
+      });
+
+      TransportCipher.EncryptedMessage emsg = 
handler.createEncryptedMessage(region);
+      ByteArrayWritableChannel channel = new 
ByteArrayWritableChannel(testDataLength);
+      // "transferTo" should act correctly when the underlying FileRegion 
transfers 0 bytes.
+      assertEquals(0L, emsg.transferTo(channel, emsg.transfered()));
+      assertEquals(testDataLength, emsg.transferTo(channel, 
emsg.transfered()));
+      assertEquals(emsg.transfered(), emsg.count());
+      assertEquals(4, channel.length());
+    } finally {
+      client.close();
+      server.close();
+    }
+  }
 }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
index 8751944..73418a9 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -124,6 +124,42 @@ public class AuthIntegrationSuite {
     }
   }
 
+  @Test
+  public void testLargeMessageEncryption() throws Exception {
+    // Use a big length to create a message that cannot be put into the 
encryption buffer completely
+    final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE;
+    ctx = new AuthTestCtx(new RpcHandler() {
+      @Override
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
+        char[] longMessage = new char[testErrorMessageLength];
+        Arrays.fill(longMessage, 'D');
+        callback.onFailure(new RuntimeException(new String(longMessage)));
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return null;
+      }
+    });
+    ctx.createServer("secret");
+    ctx.createClient("secret");
+
+    try {
+      ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
+      fail("Should have failed unencrypted RPC.");
+    } catch (Exception e) {
+      assertTrue(ctx.authRpcHandler.doDelegate);
+      assertTrue(e.getMessage() + " is not an expected error", 
e.getMessage().contains("DDDDD"));
+      // Verify we receive the complete error message
+      int messageStart = e.getMessage().indexOf("DDDDD");
+      int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5;
+      assertEquals(testErrorMessageLength, messageEnd - messageStart);
+    }
+  }
+
   private class AuthTestCtx {
 
     private final String appId = "testAppId";
@@ -136,10 +172,7 @@ public class AuthIntegrationSuite {
     volatile AuthRpcHandler authRpcHandler;
 
     AuthTestCtx() throws Exception {
-      Map<String, String> testConf = 
ImmutableMap.of("spark.network.crypto.enabled", "true");
-      this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
-
-      RpcHandler rpcHandler = new RpcHandler() {
+      this(new RpcHandler() {
         @Override
         public void receive(
             TransportClient client,
@@ -153,8 +186,12 @@ public class AuthIntegrationSuite {
         public StreamManager getStreamManager() {
           return null;
         }
-      };
+      });
+    }
 
+    AuthTestCtx(RpcHandler rpcHandler) throws Exception {
+      Map<String, String> testConf = 
ImmutableMap.of("spark.network.crypto.enabled", "true");
+      this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
       this.ctx = new TransportContext(conf, rpcHandler);
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to