This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new fce9b2b  [SPARK-25535][CORE][BRANCH-2.4] Work around bad error handling in commons-crypto.
fce9b2b is described below

commit fce9b2bce647f7554cacd1245cf670ff938f84f7
Author: Marcelo Vanzin <van...@cloudera.com>
AuthorDate: Fri Apr 26 21:23:17 2019 -0700

    [SPARK-25535][CORE][BRANCH-2.4] Work around bad error handling in commons-crypto.
    
    The commons-crypto library does some questionable error handling internally,
    which can lead to JVM crashes when a call into native code fails and the
    library cleans up state it should not.
    
    Until the library is fixed, this change adds some workarounds in Spark code
    so that when an error is detected on the commons-crypto side, Spark avoids
    calling further into the library. A simplified sketch of the pattern is
    included after the sign-off below.
    
    Tested with existing and added unit tests.
    
    Closes #24476 from vanzin/SPARK-25535-2.4.
    
    Authored-by: Marcelo Vanzin <van...@cloudera.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
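
    A minimal, illustrative sketch of this pattern (hypothetical names, not the
    committed Spark code, which lives in AuthEngine and TransportCipher): once
    commons-crypto throws InternalError, the wrapper marks itself invalid, drops
    the cipher reference, and refuses any further calls rather than re-entering
    the broken native state.

        // Hypothetical sketch only. "CipherOp" stands in for a commons-crypto
        // call (e.g. CryptoCipher.update/doFinal) that may throw InternalError
        // per CRYPTO-141.
        import java.io.IOException;

        final class GuardedCipher {
          interface CipherOp {
            byte[] run(byte[] in) throws IOException;
          }

          private CipherOp op;          // dropped once an error is observed
          private boolean valid = true;

          GuardedCipher(CipherOp op) {
            this.op = op;
          }

          byte[] doCipherOp(byte[] in) throws IOException {
            if (!valid) {
              // Mirrors the new guards: never re-enter a cipher that failed.
              throw new IOException("Cipher is invalid because of previous error.");
            }
            try {
              return op.run(in);
            } catch (InternalError ie) {
              // SPARK-25535: after InternalError the native wrapper state is
              // unreliable, so drop the cipher and refuse further calls.
              valid = false;
              op = null;
              throw ie;
            }
          }

          void close() {
            // Skip cleanup calls into the cipher if it previously failed.
            if (valid) {
              op = null;
            }
            valid = false;
          }
        }

    The patch below applies the same idea in AuthEngine.doCipherOp() and, for
    the Netty handlers, via the isCipherValid flag in TransportCipher.
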
---
 .../apache/spark/network/crypto/AuthEngine.java    |  95 +++++++++-----
 .../spark/network/crypto/TransportCipher.java      |  60 +++++++--
 .../spark/network/crypto/AuthEngineSuite.java      |  17 +++
 .../apache/spark/security/CryptoStreamUtils.scala  | 137 +++++++++++++++++++--
 .../spark/security/CryptoStreamUtilsSuite.scala    |  37 +++++-
 5 files changed, 295 insertions(+), 51 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
index 056505e..64fdb32 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -159,15 +159,21 @@ class AuthEngine implements Closeable {
     // accurately report the errors when they happen.
     RuntimeException error = null;
     byte[] dummy = new byte[8];
-    try {
-      doCipherOp(encryptor, dummy, true);
-    } catch (Exception e) {
-      error = new RuntimeException(e);
+    if (encryptor != null) {
+      try {
+        doCipherOp(Cipher.ENCRYPT_MODE, dummy, true);
+      } catch (Exception e) {
+        error = new RuntimeException(e);
+      }
+      encryptor = null;
     }
-    try {
-      doCipherOp(decryptor, dummy, true);
-    } catch (Exception e) {
-      error = new RuntimeException(e);
+    if (decryptor != null) {
+      try {
+        doCipherOp(Cipher.DECRYPT_MODE, dummy, true);
+      } catch (Exception e) {
+        error = new RuntimeException(e);
+      }
+      decryptor = null;
     }
     random.close();
 
@@ -189,11 +195,11 @@ class AuthEngine implements Closeable {
   }
 
   private byte[] decrypt(byte[] in) throws GeneralSecurityException {
-    return doCipherOp(decryptor, in, false);
+    return doCipherOp(Cipher.DECRYPT_MODE, in, false);
   }
 
   private byte[] encrypt(byte[] in) throws GeneralSecurityException {
-    return doCipherOp(encryptor, in, false);
+    return doCipherOp(Cipher.ENCRYPT_MODE, in, false);
   }
 
   private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key)
@@ -205,11 +211,13 @@ class AuthEngine implements Closeable {
     byte[] iv = new byte[conf.ivLength()];
     System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
 
-    encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
-    encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+    CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+    _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+    this.encryptor = _encryptor;
 
-    decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
-    decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+    CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
+    _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+    this.decryptor = _decryptor;
   }
 
   /**
@@ -241,29 +249,52 @@ class AuthEngine implements Closeable {
     return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
   }
 
-  private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal)
+  private byte[] doCipherOp(int mode, byte[] in, boolean isFinal)
     throws GeneralSecurityException {
 
-    Preconditions.checkState(cipher != null);
+    CryptoCipher cipher;
+    switch (mode) {
+      case Cipher.ENCRYPT_MODE:
+        cipher = encryptor;
+        break;
+      case Cipher.DECRYPT_MODE:
+        cipher = decryptor;
+        break;
+      default:
+        throw new IllegalArgumentException(String.valueOf(mode));
+    }
 
-    int scale = 1;
-    while (true) {
-      int size = in.length * scale;
-      byte[] buffer = new byte[size];
-      try {
-        int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
-          : cipher.update(in, 0, in.length, buffer, 0);
-        if (outSize != buffer.length) {
-          byte[] output = new byte[outSize];
-          System.arraycopy(buffer, 0, output, 0, output.length);
-          return output;
-        } else {
-          return buffer;
+    Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error.");
+
+    try {
+      int scale = 1;
+      while (true) {
+        int size = in.length * scale;
+        byte[] buffer = new byte[size];
+        try {
+          int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
+            : cipher.update(in, 0, in.length, buffer, 0);
+          if (outSize != buffer.length) {
+            byte[] output = new byte[outSize];
+            System.arraycopy(buffer, 0, output, 0, output.length);
+            return output;
+          } else {
+            return buffer;
+          }
+        } catch (ShortBufferException e) {
+          // Try again with a bigger buffer.
+          scale *= 2;
         }
-      } catch (ShortBufferException e) {
-        // Try again with a bigger buffer.
-        scale *= 2;
       }
+    } catch (InternalError ie) {
+      // SPARK-25535. The commons-crypto library will throw InternalError if something goes wrong,
+      // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards.
+      if (mode == Cipher.ENCRYPT_MODE) {
+        this.encryptor = null;
+      } else {
+        this.decryptor = null;
+      }
+      throw ie;
     }
   }
 
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index 0b674cc..1e0d27c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -110,10 +110,12 @@ public class TransportCipher {
   static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
     private final ByteArrayWritableChannel byteChannel;
     private final CryptoOutputStream cos;
+    private boolean isCipherValid;
 
     EncryptionHandler(TransportCipher cipher) throws IOException {
       byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
       cos = cipher.createOutputStream(byteChannel);
+      isCipherValid = true;
     }
 
     @Override
@@ -124,36 +126,61 @@ public class TransportCipher {
 
     @VisibleForTesting
     EncryptedMessage createEncryptedMessage(Object msg) {
-      return new EncryptedMessage(cos, msg, byteChannel);
+      return new EncryptedMessage(this, cos, msg, byteChannel);
     }
 
     @Override
     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
       try {
-        cos.close();
+        if (isCipherValid) {
+          cos.close();
+        }
       } finally {
         super.close(ctx, promise);
       }
     }
+
+    /**
+     * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher
+     * after an error occurs.
+     */
+    void reportError() {
+      this.isCipherValid = false;
+    }
+
+    boolean isCipherValid() {
+      return isCipherValid;
+    }
   }
 
   private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
     private final CryptoInputStream cis;
     private final ByteArrayReadableChannel byteChannel;
+    private boolean isCipherValid;
 
     DecryptionHandler(TransportCipher cipher) throws IOException {
       byteChannel = new ByteArrayReadableChannel();
       cis = cipher.createInputStream(byteChannel);
+      isCipherValid = true;
     }
 
     @Override
     public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+      if (!isCipherValid) {
+        throw new IOException("Cipher is in invalid state.");
+      }
       byteChannel.feedData((ByteBuf) data);
 
       byte[] decryptedData = new byte[byteChannel.readableBytes()];
       int offset = 0;
       while (offset < decryptedData.length) {
-        offset += cis.read(decryptedData, offset, decryptedData.length - offset);
+        // SPARK-25535: workaround for CRYPTO-141.
+        try {
+          offset += cis.read(decryptedData, offset, decryptedData.length - offset);
+        } catch (InternalError ie) {
+          isCipherValid = false;
+          throw ie;
+        }
       }
 
       ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
@@ -162,7 +189,9 @@ public class TransportCipher {
     @Override
     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
       try {
-        cis.close();
+        if (isCipherValid) {
+          cis.close();
+        }
       } finally {
         super.channelInactive(ctx);
       }
@@ -175,8 +204,9 @@ public class TransportCipher {
     private final ByteBuf buf;
     private final FileRegion region;
     private final long count;
+    private final CryptoOutputStream cos;
+    private final EncryptionHandler handler;
     private long transferred;
-    private CryptoOutputStream cos;
 
     // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has
     // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data
@@ -186,9 +216,14 @@ public class TransportCipher {
 
     private ByteBuffer currentEncrypted;
 
-    EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) {
+    EncryptedMessage(
+        EncryptionHandler handler,
+        CryptoOutputStream cos,
+        Object msg,
+        ByteArrayWritableChannel ch) {
       Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
         "Unrecognized message type: %s", msg.getClass().getName());
+      this.handler = handler;
       this.isByteBuf = msg instanceof ByteBuf;
       this.buf = isByteBuf ? (ByteBuf) msg : null;
       this.region = isByteBuf ? null : (FileRegion) msg;
@@ -288,6 +323,9 @@ public class TransportCipher {
     }
 
     private void encryptMore() throws IOException {
+      if (!handler.isCipherValid()) {
+        throw new IOException("Cipher is in invalid state.");
+      }
       byteRawChannel.reset();
 
       if (isByteBuf) {
@@ -296,8 +334,14 @@ public class TransportCipher {
       } else {
         region.transferTo(byteRawChannel, region.transferred());
       }
-      cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
-      cos.flush();
+
+      try {
+        cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
+        cos.flush();
+      } catch (InternalError ie) {
+        handler.reportError();
+        throw ie;
+      }
 
       currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
         0, byteEncChannel.length());
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
index 46b6305..382b733 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -20,10 +20,13 @@ package org.apache.spark.network.crypto;
 import java.nio.ByteBuffer;
 import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
+import java.util.Map;
+import java.security.InvalidKeyException;
 import java.util.Random;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 
+import com.google.common.collect.ImmutableMap;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.FileRegion;
@@ -189,4 +192,18 @@ public class AuthEngineSuite {
       server.close();
     }
   }
+
+  @Test(expected = InvalidKeyException.class)
+  public void testBadKeySize() throws Exception {
+    Map<String, String> mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42");
+    TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf));
+
+    try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) {
+      engine.challenge();
+      fail("Should have failed to create challenge message.");
+
+      // Call close explicitly to make sure it's idempotent.
+      engine.close();
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index 0062197..18b735b 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.security
 
-import java.io.{InputStream, OutputStream}
+import java.io.{Closeable, InputStream, IOException, OutputStream}
 import java.nio.ByteBuffer
 import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
 import java.util.Properties
@@ -54,8 +54,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     val params = new CryptoParams(key, sparkConf)
     val iv = createInitializationVector(params.conf)
     os.write(iv)
-    new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingOutputStream(
+      new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
+        new IvParameterSpec(iv)),
+      os)
   }
 
   /**
@@ -70,8 +72,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     val helper = new CryptoHelperChannel(channel)
 
     helper.write(ByteBuffer.wrap(iv))
-    new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingWritableChannel(
+      new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
+        new IvParameterSpec(iv)),
+      helper)
   }
 
   /**
@@ -84,8 +88,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
     ByteStreams.readFully(is, iv)
     val params = new CryptoParams(key, sparkConf)
-    new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingInputStream(
+      new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
+        new IvParameterSpec(iv)),
+      is)
   }
 
   /**
@@ -100,8 +106,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     JavaUtils.readFully(channel, buf)
 
     val params = new CryptoParams(key, sparkConf)
-    new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingReadableChannel(
+      new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
+        new IvParameterSpec(iv)),
+      channel)
   }
 
   def toCryptoConf(conf: SparkConf): Properties = {
@@ -157,6 +165,117 @@ private[spark] object CryptoStreamUtils extends Logging {
 
   }
 
+  /**
+   * SPARK-25535. The commons-crypto library will throw InternalError if something goes
+   * wrong, and leave bad state behind in the Java wrappers, so it's not safe to use them
+   * afterwards. This wrapper detects that situation and avoids further calls into the
+   * commons-crypto code, while still allowing the underlying streams to be closed.
+   *
+   * This should be removed once CRYPTO-141 is fixed (and Spark upgrades its commons-crypto
+   * dependency).
+   */
+  trait BaseErrorHandler extends Closeable {
+
+    private var closed = false
+
+    /** The encrypted stream that may get into an unhealthy state. */
+    protected def cipherStream: Closeable
+
+    /**
+     * The underlying stream that is being wrapped by the encrypted stream, so that it can be
+     * closed even if there's an error in the crypto layer.
+     */
+    protected def original: Closeable
+
+    protected def safeCall[T](fn: => T): T = {
+      if (closed) {
+        throw new IOException("Cipher stream is closed.")
+      }
+      try {
+        fn
+      } catch {
+        case ie: InternalError =>
+          closed = true
+          original.close()
+          throw ie
+      }
+    }
+
+    override def close(): Unit = {
+      if (!closed) {
+        cipherStream.close()
+      }
+    }
+
+  }
+
+  // Visible for testing.
+  class ErrorHandlingReadableChannel(
+      protected val cipherStream: ReadableByteChannel,
+      protected val original: ReadableByteChannel)
+    extends ReadableByteChannel with BaseErrorHandler {
+
+    override def read(src: ByteBuffer): Int = safeCall {
+      cipherStream.read(src)
+    }
+
+    override def isOpen(): Boolean = cipherStream.isOpen()
+
+  }
+
+  private class ErrorHandlingInputStream(
+      protected val cipherStream: InputStream,
+      protected val original: InputStream)
+    extends InputStream with BaseErrorHandler {
+
+    override def read(b: Array[Byte]): Int = safeCall {
+      cipherStream.read(b)
+    }
+
+    override def read(b: Array[Byte], off: Int, len: Int): Int = safeCall {
+      cipherStream.read(b, off, len)
+    }
+
+    override def read(): Int = safeCall {
+      cipherStream.read()
+    }
+  }
+
+  private class ErrorHandlingWritableChannel(
+      protected val cipherStream: WritableByteChannel,
+      protected val original: WritableByteChannel)
+    extends WritableByteChannel with BaseErrorHandler {
+
+    override def write(src: ByteBuffer): Int = safeCall {
+      cipherStream.write(src)
+    }
+
+    override def isOpen(): Boolean = cipherStream.isOpen()
+
+  }
+
+  private class ErrorHandlingOutputStream(
+      protected val cipherStream: OutputStream,
+      protected val original: OutputStream)
+    extends OutputStream with BaseErrorHandler {
+
+    override def flush(): Unit = safeCall {
+      cipherStream.flush()
+    }
+
+    override def write(b: Array[Byte]): Unit = safeCall {
+      cipherStream.write(b)
+    }
+
+    override def write(b: Array[Byte], off: Int, len: Int): Unit = safeCall {
+      cipherStream.write(b, off, len)
+    }
+
+    override def write(b: Int): Unit = safeCall {
+      cipherStream.write(b)
+    }
+  }
+
   private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {
 
     val keySpec = new SecretKeySpec(key, "AES")
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 78f618f..0d3611c 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,13 +16,16 @@
  */
 package org.apache.spark.security
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream}
-import java.nio.channels.Channels
+import java.io._
+import java.nio.ByteBuffer
+import java.nio.channels.{Channels, ReadableByteChannel}
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.Files
 import java.util.{Arrays, Random, UUID}
 
 import com.google.common.io.ByteStreams
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
 
 import org.apache.spark._
 import org.apache.spark.internal.config._
@@ -164,6 +167,36 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
     }
   }
 
+  test("error handling wrapper") {
+    val wrapped = mock(classOf[ReadableByteChannel])
+    val decrypted = mock(classOf[ReadableByteChannel])
+    val errorHandler = new CryptoStreamUtils.ErrorHandlingReadableChannel(decrypted, wrapped)
+
+    when(decrypted.read(any(classOf[ByteBuffer])))
+      .thenThrow(new IOException())
+      .thenThrow(new InternalError())
+      .thenReturn(1)
+
+    val out = ByteBuffer.allocate(1)
+    intercept[IOException] {
+      errorHandler.read(out)
+    }
+    intercept[InternalError] {
+      errorHandler.read(out)
+    }
+
+    val e = intercept[IOException] {
+      errorHandler.read(out)
+    }
+    assert(e.getMessage().contains("is closed"))
+    errorHandler.close()
+
+    verify(decrypted, times(2)).read(any(classOf[ByteBuffer]))
+    verify(wrapped, never()).read(any(classOf[ByteBuffer]))
+    verify(decrypted, never()).close()
+    verify(wrapped, times(1)).close()
+  }
+
   private def createConf(extra: (String, String)*): SparkConf = {
     val conf = new SparkConf()
     extra.foreach { case (k, v) => conf.set(k, v) }

