Repository: spark
Updated Branches:
  refs/heads/master 5ddf69470 -> 4f15d94cf


[SPARK-13331] AES support for over-the-wire encryption

## What changes were proposed in this pull request?

The DIGEST-MD5 mechanism is used for SASL authentication and secure communication.
DIGEST-MD5 supports the 3DES, DES, and RC4 ciphers; however, these ciphers are
relatively slow.

AES provides better performance and security by design and, according to NIST, is
the replacement for 3DES. Apache Commons Crypto is a cryptographic library optimized
with AES-NI; this patch employs Apache Commons Crypto as the encryption/decryption
backend for the secure channel set up after SASL authentication, to improve Spark RPC.
## How was this patch tested?

Unit tests and integration tests.

Author: Junjie Chen <junjie.j.c...@intel.com>

Closes #15172 from cjjnjust/shuffle_rpc_encrypt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f15d94c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f15d94c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f15d94c

Branch: refs/heads/master
Commit: 4f15d94cfec86130f8dab28ae2e228ded8124020
Parents: 5ddf694
Author: Junjie Chen <junjie.j.c...@intel.com>
Authored: Fri Nov 11 10:37:58 2016 -0800
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Fri Nov 11 10:37:58 2016 -0800

----------------------------------------------------------------------
 common/network-common/pom.xml                   |   4 +
 .../spark/network/sasl/SaslClientBootstrap.java |  23 +-
 .../spark/network/sasl/SaslRpcHandler.java      | 101 +++++--
 .../spark/network/sasl/aes/AesCipher.java       | 294 +++++++++++++++++++
 .../network/sasl/aes/AesConfigMessage.java      | 101 +++++++
 .../network/util/ByteArrayReadableChannel.java  |  62 ++++
 .../spark/network/util/TransportConf.java       |  22 ++
 .../spark/network/sasl/SparkSaslSuite.java      |  93 +++++-
 docs/configuration.md                           |  26 ++
 9 files changed, 689 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/pom.xml
----------------------------------------------------------------------
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index fcefe64..ca99fa8 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -76,6 +76,10 @@
       <artifactId>guava</artifactId>
       <scope>compile</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-crypto</artifactId>
+    </dependency>
 
     <!-- Test dependencies -->
     <dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 9e5c616..a1bb453 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -30,6 +30,8 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.aes.AesCipher;
+import org.apache.spark.network.sasl.aes.AesConfigMessage;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -88,9 +90,26 @@ public class SaslClientBootstrap implements 
TransportClientBootstrap {
           throw new RuntimeException(
             new SaslException("Encryption requests by negotiated non-encrypted 
connection."));
         }
-        SaslEncryption.addToChannel(channel, saslClient, 
conf.maxSaslEncryptedBlockSize());
+
+        if (conf.aesEncryptionEnabled()) {
+          // Generate a request config message to send to server.
+          AesConfigMessage configMessage = AesCipher.createConfigMessage(conf);
+          ByteBuffer buf = configMessage.encodeMessage();
+
+          // Encrypted the config message.
+          byte[] toEncrypt = JavaUtils.bufferToArray(buf);
+          ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, 
toEncrypt.length));
+
+          client.sendRpcSync(encrypted, conf.saslRTTimeoutMs());
+          AesCipher cipher = new AesCipher(configMessage, conf);
+          logger.info("Enabling AES cipher for client channel {}", client);
+          cipher.addToChannel(channel);
+          saslClient.dispose();
+        } else {
+          SaslEncryption.addToChannel(channel, saslClient, 
conf.maxSaslEncryptedBlockSize());
+        }
         saslClient = null;
-        logger.debug("Channel {} configured for SASL encryption.", client);
+        logger.debug("Channel {} configured for encryption.", client);
       }
     } catch (IOException ioe) {
       throw new RuntimeException(ioe);

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index c41f5b6..b2f3ef2 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -29,6 +29,8 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.aes.AesCipher;
+import org.apache.spark.network.sasl.aes.AesConfigMessage;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.util.JavaUtils;
@@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler {
 
   private SparkSaslServer saslServer;
   private boolean isComplete;
+  private boolean isAuthenticated;
 
   SaslRpcHandler(
       TransportConf conf,
@@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler {
     this.secretKeyHolder = secretKeyHolder;
     this.saslServer = null;
     this.isComplete = false;
+    this.isAuthenticated = false;
   }
 
   @Override
@@ -80,30 +84,31 @@ class SaslRpcHandler extends RpcHandler {
       delegate.receive(client, message, callback);
       return;
     }
+    if (saslServer == null || !saslServer.isComplete()) {
+      ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+      SaslMessage saslMessage;
+      try {
+        saslMessage = SaslMessage.decode(nettyBuf);
+      } finally {
+        nettyBuf.release();
+      }
 
-    ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
-    SaslMessage saslMessage;
-    try {
-      saslMessage = SaslMessage.decode(nettyBuf);
-    } finally {
-      nettyBuf.release();
-    }
-
-    if (saslServer == null) {
-      // First message in the handshake, setup the necessary state.
-      client.setClientId(saslMessage.appId);
-      saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
-        conf.saslServerAlwaysEncrypt());
-    }
+      if (saslServer == null) {
+        // First message in the handshake, setup the necessary state.
+        client.setClientId(saslMessage.appId);
+        saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
+          conf.saslServerAlwaysEncrypt());
+      }
 
-    byte[] response;
-    try {
-      response = saslServer.response(JavaUtils.bufferToArray(
-        saslMessage.body().nioByteBuffer()));
-    } catch (IOException ioe) {
-      throw new RuntimeException(ioe);
+      byte[] response;
+      try {
+        response = saslServer.response(JavaUtils.bufferToArray(
+          saslMessage.body().nioByteBuffer()));
+      } catch (IOException ioe) {
+        throw new RuntimeException(ioe);
+      }
+      callback.onSuccess(ByteBuffer.wrap(response));
     }
-    callback.onSuccess(ByteBuffer.wrap(response));
 
     // Setup encryption after the SASL response is sent, otherwise the client 
can't parse the
     // response. It's ok to change the channel pipeline here since we are 
processing an incoming
@@ -111,15 +116,42 @@ class SaslRpcHandler extends RpcHandler {
     // method returns. This assumes that the code ensures, through other 
means, that no outbound
     // messages are being written to the channel while negotiation is still 
going on.
     if (saslServer.isComplete()) {
-      logger.debug("SASL authentication successful for channel {}", client);
-      isComplete = true;
-      if 
(SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP)))
 {
+      if 
(!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP)))
 {
+        logger.debug("SASL authentication successful for channel {}", client);
+        complete(true);
+        return;
+      }
+
+      if (!conf.aesEncryptionEnabled()) {
         logger.debug("Enabling encryption for channel {}", client);
         SaslEncryption.addToChannel(channel, saslServer, 
conf.maxSaslEncryptedBlockSize());
-        saslServer = null;
-      } else {
-        saslServer.dispose();
-        saslServer = null;
+        complete(false);
+        return;
+      }
+
+      // Extra negotiation should happen after authentication, so return 
directly while
+      // processing authenticate.
+      if (!isAuthenticated) {
+        logger.debug("SASL authentication successful for channel {}", client);
+        isAuthenticated = true;
+        return;
+      }
+
+      // Create AES cipher when it is authenticated
+      try {
+        byte[] encrypted = JavaUtils.bufferToArray(message);
+        ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 
, encrypted.length));
+
+        AesConfigMessage configMessage = 
AesConfigMessage.decodeMessage(decrypted);
+        AesCipher cipher = new AesCipher(configMessage, conf);
+
+        // Send response back to client to confirm that server accept config.
+        callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM));
+        logger.info("Enabling AES cipher for Server channel {}", client);
+        cipher.addToChannel(channel);
+        complete(true);
+      } catch (IOException ioe) {
+        throw new RuntimeException(ioe);
       }
     }
   }
@@ -155,4 +187,17 @@ class SaslRpcHandler extends RpcHandler {
     delegate.exceptionCaught(cause, client);
   }
 
+  private void complete(boolean dispose) {
+    if (dispose) {
+      try {
+        saslServer.dispose();
+      } catch (RuntimeException e) {
+        logger.error("Error while disposing SASL server", e);
+      }
+    }
+
+    saslServer = null;
+    isComplete = true;
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
new file mode 100644
index 0000000..78034a6
--- /dev/null
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl.aes;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.WritableByteChannel;
+import java.util.Properties;
+import javax.crypto.spec.SecretKeySpec;
+import javax.crypto.spec.IvParameterSpec;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.*;
+import io.netty.util.AbstractReferenceCounted;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
+import org.apache.commons.crypto.random.CryptoRandom;
+import org.apache.commons.crypto.random.CryptoRandomFactory;
+import org.apache.commons.crypto.stream.CryptoInputStream;
+import org.apache.commons.crypto.stream.CryptoOutputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.util.ByteArrayReadableChannel;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * AES cipher for encryption and decryption.
+ */
+public class AesCipher {
+  private static final Logger logger = 
LoggerFactory.getLogger(AesCipher.class);
+  public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption";
+  public static final String DECRYPTION_HANDLER_NAME = "AesDecryption";
+  public static final int STREAM_BUFFER_SIZE = 1024 * 32;
+  public static final String TRANSFORM = "AES/CTR/NoPadding";
+
+  private final SecretKeySpec inKeySpec;
+  private final IvParameterSpec inIvSpec;
+  private final SecretKeySpec outKeySpec;
+  private final IvParameterSpec outIvSpec;
+  private final Properties properties;
+
+  public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws 
IOException  {
+    this.properties = CryptoStreamUtils.toCryptoConf(conf);
+    this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES");
+    this.inIvSpec = new IvParameterSpec(configMessage.inIv);
+    this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES");
+    this.outIvSpec = new IvParameterSpec(configMessage.outIv);
+  }
+
+  /**
+   * Create AES crypto output stream
+   * @param ch The underlying channel to write out.
+   * @return Return output crypto stream for encryption.
+   * @throws IOException
+   */
+  private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws 
IOException {
+    return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, 
outIvSpec);
+  }
+
+  /**
+   * Create AES crypto input stream
+   * @param ch The underlying channel used to read data.
+   * @return Return input crypto stream for decryption.
+   * @throws IOException
+   */
+  private CryptoInputStream createInputStream(ReadableByteChannel ch) throws 
IOException {
+    return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, 
inIvSpec);
+  }
+
+  /**
+   * Add handlers to channel
+   * @param ch the channel for adding handlers
+   * @throws IOException
+   */
+  public void addToChannel(Channel ch) throws IOException {
+    ch.pipeline()
+      .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this))
+      .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this));
+  }
+
+  /**
+   * Create the configuration message
+   * @param conf is the local transport configuration.
+   * @return Config message for sending.
+   */
+  public static AesConfigMessage createConfigMessage(TransportConf conf) {
+    int keySize = conf.aesCipherKeySize();
+    Properties properties = CryptoStreamUtils.toCryptoConf(conf);
+
+    try {
+      int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, 
properties)
+        .getBlockSize();
+      byte[] inKey = new byte[keySize];
+      byte[] outKey = new byte[keySize];
+      byte[] inIv = new byte[paramLen];
+      byte[] outIv = new byte[paramLen];
+
+      CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
+      random.nextBytes(inKey);
+      random.nextBytes(outKey);
+      random.nextBytes(inIv);
+      random.nextBytes(outIv);
+
+      return new AesConfigMessage(inKey, inIv, outKey, outIv);
+    } catch (Exception e) {
+      logger.error("AES config error", e);
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * CryptoStreamUtils is used to convert config from TransportConf to AES 
Crypto config.
+   */
+  private static class CryptoStreamUtils {
+    public static Properties toCryptoConf(TransportConf conf) {
+      Properties props = new Properties();
+      if (conf.aesCipherClass() != null) {
+        props.setProperty(CryptoCipherFactory.CLASSES_KEY, 
conf.aesCipherClass());
+      }
+      return props;
+    }
+  }
+
+  private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter 
{
+    private final ByteArrayWritableChannel byteChannel;
+    private final CryptoOutputStream cos;
+
+    AesEncryptHandler(AesCipher cipher) throws IOException {
+      byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+      cos = cipher.createOutputStream(byteChannel);
+    }
+
+    @Override
+    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise)
+      throws Exception {
+      ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise);
+    }
+
+    @Override
+    public void close(ChannelHandlerContext ctx, ChannelPromise promise) 
throws Exception {
+      try {
+        cos.close();
+      } finally {
+        super.close(ctx, promise);
+      }
+    }
+  }
+
+  private static class AesDecryptHandler extends ChannelInboundHandlerAdapter {
+    private final CryptoInputStream cis;
+    private final ByteArrayReadableChannel byteChannel;
+
+    AesDecryptHandler(AesCipher cipher) throws IOException {
+      byteChannel = new ByteArrayReadableChannel();
+      cis = cipher.createInputStream(byteChannel);
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object data) throws 
Exception {
+      byteChannel.feedData((ByteBuf) data);
+
+      byte[] decryptedData = new byte[byteChannel.readableBytes()];
+      int offset = 0;
+      while (offset < decryptedData.length) {
+        offset += cis.read(decryptedData, offset, decryptedData.length - 
offset);
+      }
+
+      ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, 
decryptedData.length));
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+      try {
+        cis.close();
+      } finally {
+        super.channelInactive(ctx);
+      }
+    }
+  }
+
+  private static class EncryptedMessage extends AbstractReferenceCounted 
implements FileRegion {
+    private final boolean isByteBuf;
+    private final ByteBuf buf;
+    private final FileRegion region;
+    private long transferred;
+    private CryptoOutputStream cos;
+
+    // Due to streaming issue CRYPTO-125: 
https://issues.apache.org/jira/browse/CRYPTO-125, it has
+    // to utilize two helper ByteArrayWritableChannel for streaming. One is 
used to receive raw data
+    // from upper handler, another is used to store encrypted data.
+    private ByteArrayWritableChannel byteEncChannel;
+    private ByteArrayWritableChannel byteRawChannel;
+
+    private ByteBuffer currentEncrypted;
+
+    EncryptedMessage(CryptoOutputStream cos, Object msg, 
ByteArrayWritableChannel ch) {
+      Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof 
FileRegion,
+        "Unrecognized message type: %s", msg.getClass().getName());
+      this.isByteBuf = msg instanceof ByteBuf;
+      this.buf = isByteBuf ? (ByteBuf) msg : null;
+      this.region = isByteBuf ? null : (FileRegion) msg;
+      this.transferred = 0;
+      this.byteRawChannel = new 
ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE);
+      this.cos = cos;
+      this.byteEncChannel = ch;
+    }
+
+    @Override
+    public long count() {
+      return isByteBuf ? buf.readableBytes() : region.count();
+    }
+
+    @Override
+    public long position() {
+      return 0;
+    }
+
+    @Override
+    public long transfered() {
+      return transferred;
+    }
+
+    @Override
+    public long transferTo(WritableByteChannel target, long position) throws 
IOException {
+      Preconditions.checkArgument(position == transfered(), "Invalid 
position.");
+
+      do {
+        if (currentEncrypted == null) {
+          encryptMore();
+        }
+
+        int bytesWritten = currentEncrypted.remaining();
+        target.write(currentEncrypted);
+        bytesWritten -= currentEncrypted.remaining();
+        transferred += bytesWritten;
+        if (!currentEncrypted.hasRemaining()) {
+          currentEncrypted = null;
+          byteEncChannel.reset();
+        }
+      } while (transferred < count());
+
+      return transferred;
+    }
+
+    private void encryptMore() throws IOException {
+      byteRawChannel.reset();
+
+      if (isByteBuf) {
+        int copied = byteRawChannel.write(buf.nioBuffer());
+        buf.skipBytes(copied);
+      } else {
+        region.transferTo(byteRawChannel, region.transfered());
+      }
+      cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
+      cos.flush();
+
+      currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
+        0, byteEncChannel.length());
+    }
+
+    @Override
+    protected void deallocate() {
+      byteRawChannel.reset();
+      byteEncChannel.reset();
+      if (region != null) {
+        region.release();
+      }
+      if (buf != null) {
+        buf.release();
+      }
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
new file mode 100644
index 0000000..3ef6f74
--- /dev/null
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl.aes;
+
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * The AES cipher options for encryption negotiation.
+ */
+public class AesConfigMessage implements Encodable {
+  /** Serialization tag used to catch incorrect payloads. */
+  private static final byte TAG_BYTE = (byte) 0xEB;
+
+  public byte[] inKey;
+  public byte[] outKey;
+  public byte[] inIv;
+  public byte[] outIv;
+
+  public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] 
outIv) {
+    if (inKey == null || inIv == null || outKey == null || outIv == null) {
+      throw new IllegalArgumentException("Cipher Key or IV must not be null!");
+    }
+
+    this.inKey = inKey;
+    this.inIv = inIv;
+    this.outKey = outKey;
+    this.outIv = outIv;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 1 +
+      Encoders.ByteArrays.encodedLength(inKey) + 
Encoders.ByteArrays.encodedLength(outKey) +
+      Encoders.ByteArrays.encodedLength(inIv) + 
Encoders.ByteArrays.encodedLength(outIv);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeByte(TAG_BYTE);
+    Encoders.ByteArrays.encode(buf, inKey);
+    Encoders.ByteArrays.encode(buf, inIv);
+    Encoders.ByteArrays.encode(buf, outKey);
+    Encoders.ByteArrays.encode(buf, outIv);
+  }
+
+  /**
+   * Encode the config message.
+   * @return ByteBuffer which contains encoded config message.
+   */
+  public ByteBuffer encodeMessage(){
+    ByteBuffer buf = ByteBuffer.allocate(encodedLength());
+
+    ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf);
+    wrappedBuf.clear();
+    encode(wrappedBuf);
+
+    return buf;
+  }
+
+  /**
+   * Decode the config message from buffer
+   * @param buffer the buffer contain encoded config message
+   * @return config message
+   */
+  public static AesConfigMessage decodeMessage(ByteBuffer buffer) {
+    ByteBuf buf = Unpooled.wrappedBuffer(buffer);
+
+    if (buf.readByte() != TAG_BYTE) {
+      throw new IllegalStateException("Expected AesConfigMessage, received 
something else"
+        + " (maybe your client does not have AES enabled?)");
+    }
+
+    byte[] outKey = Encoders.ByteArrays.decode(buf);
+    byte[] outIv = Encoders.ByteArrays.decode(buf);
+    byte[] inKey = Encoders.ByteArrays.decode(buf);
+    byte[] inIv = Encoders.ByteArrays.decode(buf);
+    return new AesConfigMessage(inKey, inIv, outKey, outIv);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
new file mode 100644
index 0000000..25d103d
--- /dev/null
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+
+import io.netty.buffer.ByteBuf;
+
+public class ByteArrayReadableChannel implements ReadableByteChannel {
+  private ByteBuf data;
+
+  public int readableBytes() {
+    return data.readableBytes();
+  }
+
+  public void feedData(ByteBuf buf) {
+    data = buf;
+  }
+
+  @Override
+  public int read(ByteBuffer dst) throws IOException {
+    int totalRead = 0;
+    while (data.readableBytes() > 0 && dst.remaining() > 0) {
+      int bytesToRead = Math.min(data.readableBytes(), dst.remaining());
+      dst.put(data.readSlice(bytesToRead).nioBuffer());
+      totalRead += bytesToRead;
+    }
+
+    if (data.readableBytes() == 0) {
+      data.release();
+    }
+
+    return totalRead;
+  }
+
+  @Override
+  public void close() throws IOException {
+  }
+
+  @Override
+  public boolean isOpen() {
+    return true;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 64eaba1..d0d0728 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.util;
 
 import com.google.common.primitives.Ints;
+import org.apache.commons.crypto.cipher.CryptoCipherFactory;
 
 /**
  * A central location that tracks all the settings we expose to users.
@@ -175,4 +176,25 @@ public class TransportConf {
     return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
   }
 
+  /**
+   * The trigger for enabling AES encryption.
+   */
+  public boolean aesEncryptionEnabled() {
+    return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false);
+  }
+
+  /**
+   * The implementation class for crypto cipher
+   */
+  public String aesCipherClass() {
+    return conf.get("spark.authenticate.encryption.aes.cipher.class", null);
+  }
+
+  /**
+   * The bytes of AES cipher key which is effective when AES cipher is 
enabled. Notice that
+   * the length should be 16, 24 or 32 bytes.
+   */
+  public int aesCipherKeySize() {
+    return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 45cc03d..4e6146c 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -53,6 +53,7 @@ import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.sasl.aes.AesCipher;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
@@ -149,7 +150,7 @@ public class SparkSaslSuite {
       .when(rpcHandler)
       .receive(any(TransportClient.class), any(ByteBuffer.class), 
any(RpcResponseCallback.class));
 
-    SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
+    SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false);
     try {
       ByteBuffer response = 
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
         TimeUnit.SECONDS.toMillis(10));
@@ -275,7 +276,7 @@ public class SparkSaslSuite {
       new Random().nextBytes(data);
       Files.write(data, file);
 
-      ctx = new SaslTestCtx(rpcHandler, true, false);
+      ctx = new SaslTestCtx(rpcHandler, true, false, false);
 
       final CountDownLatch lock = new CountDownLatch(1);
 
@@ -317,7 +318,7 @@ public class SparkSaslSuite {
 
     SaslTestCtx ctx = null;
     try {
-      ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
+      ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false);
       fail("Should have failed to connect without encryption.");
     } catch (Exception e) {
       assertTrue(e.getCause() instanceof SaslException);
@@ -336,7 +337,7 @@ public class SparkSaslSuite {
     // able to understand RPCs sent to it and thus close the connection.
     SaslTestCtx ctx = null;
     try {
-      ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
+      ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false);
       ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
         TimeUnit.SECONDS.toMillis(10));
       fail("Should have failed to send RPC to server.");
@@ -374,6 +375,69 @@ public class SparkSaslSuite {
     }
   }
 
+  /**
+   * Fetches a 16 MB file-backed chunk over a SASL channel with AES encryption enabled and
+   * checks that the bytes received by the client match the file contents.
+   */
+  @Test
+  public void testAesEncryption() throws Exception {
+    final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
+    final File file = File.createTempFile("sasltest", ".txt");
+    SaslTestCtx ctx = null;
+    try {
+      final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider());
+      final TransportConf spyConf = spy(conf);
+      doReturn(true).when(spyConf).aesEncryptionEnabled();
+
+      StreamManager sm = mock(StreamManager.class);
+      when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
+        @Override
+        public ManagedBuffer answer(InvocationOnMock invocation) {
+          return new FileSegmentManagedBuffer(spyConf, file, 0, file.length());
+        }
+      });
+
+      RpcHandler rpcHandler = mock(RpcHandler.class);
+      when(rpcHandler.getStreamManager()).thenReturn(sm);
+
+      byte[] data = new byte[16 * 1024 * 1024];
+      new Random().nextBytes(data);
+      Files.write(data, file);
+
+      ctx = new SaslTestCtx(rpcHandler, true, false, true);
+
+      final CountDownLatch lock = new CountDownLatch(1);
+
+      ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+      doAnswer(new Answer<Void>() {
+        @Override
+        public Void answer(InvocationOnMock invocation) {
+          response.set((ManagedBuffer) invocation.getArguments()[1]);
+          response.get().retain();
+          lock.countDown();
+          return null;
+        }
+      }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
+
+      ctx.client.fetchChunk(0, 0, callback);
+      assertTrue("Timed out waiting for chunk.", lock.await(10, TimeUnit.SECONDS));
+
+      verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
+      verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
+
+      byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
+      assertTrue(Arrays.equals(data, received));
+    } finally {
+      file.delete();
+      if (ctx != null) {
+        ctx.close();
+      }
+      if (response.get() != null) {
+        response.get().release();
+      }
+    }
+  }
+
   private static class SaslTestCtx {
 
     final TransportClient client;
@@ -386,18 +450,28 @@ public class SparkSaslSuite {
     SaslTestCtx(
         RpcHandler rpcHandler,
         boolean encrypt,
-        boolean disableClientEncryption)
+        boolean disableClientEncryption,
+        boolean aesEnable)
       throws Exception {
 
       TransportConf conf = new TransportConf("shuffle", new 
SystemPropertyConfigProvider());
 
+      if (aesEnable) {
+        conf = spy(conf);
+        doReturn(true).when(conf).aesEncryptionEnabled();
+      }
+
       SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
       when(keyHolder.getSaslUser(anyString())).thenReturn("user");
       when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
 
       TransportContext ctx = new TransportContext(conf, rpcHandler);
 
-      this.checker = new EncryptionCheckerBootstrap();
+      String encryptHandlerName = aesEnable ? 
AesCipher.ENCRYPTION_HANDLER_NAME :
+        SaslEncryption.ENCRYPTION_HANDLER_NAME;
+
+      this.checker = new EncryptionCheckerBootstrap(encryptHandlerName);
+
       this.server = ctx.createServer(Arrays.asList(new 
SaslServerBootstrap(conf, keyHolder),
         checker));
 
@@ -437,13 +511,18 @@ public class SparkSaslSuite {
     implements TransportServerBootstrap {
 
     boolean foundEncryptionHandler;
+    String encryptHandlerName;
+
+    public EncryptionCheckerBootstrap(String encryptHandlerName) {
+      this.encryptHandlerName = encryptHandlerName;
+    }
 
     @Override
     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise)
       throws Exception {
       if (!foundEncryptionHandler) {
         foundEncryptionHandler =
-          ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) 
!= null;
+          ctx.channel().pipeline().get(encryptHandlerName) != null;
       }
       ctx.write(msg, promise);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f15d94c/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index d0acd94..41c1778 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1530,6 +1530,32 @@ Apart from these, the following properties are also available, and may be useful
   </td>
 </tr>
 <tr>
+  <td><code>spark.authenticate.encryption.aes.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enable AES for over-the-wire encryption, instead of the ciphers provided by SASL DIGEST-MD5.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.authenticate.encryption.aes.cipher.keySize</code></td>
+  <td>16</td>
+  <td>
+    The size of the AES cipher key, in bytes. Only effective when AES encryption is enabled. AES
+    works with 16-, 24- and 32-byte keys.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.authenticate.encryption.aes.cipher.class</code></td>
+  <td>(none)</td>
+  <td>
+    Specify the underlying implementation class of the crypto cipher. Leave unset to use the
+    default. To use <code>OpenSslCipher</code>, OpenSSL must be installed. Currently, two cipher
+    classes are available in the Commons Crypto library:
+    <code>org.apache.commons.crypto.cipher.OpenSslCipher</code> and
+    <code>org.apache.commons.crypto.cipher.JceCipher</code>.
+  </td>
+</tr>
+<tr>
   <td><code>spark.core.connection.ack.wait.timeout</code></td>
   <td>60s</td>
   <td>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to