Repository: spark Updated Branches: refs/heads/master deb9588b2 -> 3eee9e024
[SPARK-25535][CORE] Work around bad error handling in commons-crypto. The commons-crypto library does some questionable error handling internally, which can lead to JVM crashes if some call into native code fails and cleans up state it should not. While the library is not fixed, this change adds some workarounds in Spark code so that when an error is detected in the commons-crypto side, Spark avoids calling into the library further. Tested with existing and added unit tests. Closes #22557 from vanzin/SPARK-25535. Authored-by: Marcelo Vanzin <van...@cloudera.com> Signed-off-by: Imran Rashid <iras...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3eee9e02 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3eee9e02 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3eee9e02 Branch: refs/heads/master Commit: 3eee9e02463e10570a29fad00823c953debd945e Parents: deb9588 Author: Marcelo Vanzin <van...@cloudera.com> Authored: Tue Oct 9 09:27:08 2018 -0500 Committer: Imran Rashid <iras...@cloudera.com> Committed: Tue Oct 9 09:27:08 2018 -0500 ---------------------------------------------------------------------- .../apache/spark/network/crypto/AuthEngine.java | 95 ++++++++----- .../spark/network/crypto/TransportCipher.java | 60 ++++++-- .../spark/network/crypto/AuthEngineSuite.java | 17 +++ .../spark/security/CryptoStreamUtils.scala | 137 +++++++++++++++++-- .../spark/security/CryptoStreamUtilsSuite.scala | 37 ++++- 5 files changed, 295 insertions(+), 51 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java ---------------------------------------------------------------------- diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index 056505e..64fdb32 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -159,15 +159,21 @@ class AuthEngine implements Closeable { // accurately report the errors when they happen. RuntimeException error = null; byte[] dummy = new byte[8]; - try { - doCipherOp(encryptor, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); + if (encryptor != null) { + try { + doCipherOp(Cipher.ENCRYPT_MODE, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + encryptor = null; } - try { - doCipherOp(decryptor, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); + if (decryptor != null) { + try { + doCipherOp(Cipher.DECRYPT_MODE, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + decryptor = null; } random.close(); @@ -189,11 +195,11 @@ class AuthEngine implements Closeable { } private byte[] decrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(decryptor, in, false); + return doCipherOp(Cipher.DECRYPT_MODE, in, false); } private byte[] encrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(encryptor, in, false); + return doCipherOp(Cipher.ENCRYPT_MODE, in, false); } private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) @@ -205,11 +211,13 @@ class AuthEngine implements Closeable { byte[] iv = new byte[conf.ivLength()]; System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); - encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + this.encryptor = 
_encryptor; - decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + this.decryptor = _decryptor; } /** @@ -241,29 +249,52 @@ class AuthEngine implements Closeable { return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); } - private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal) + private byte[] doCipherOp(int mode, byte[] in, boolean isFinal) throws GeneralSecurityException { - Preconditions.checkState(cipher != null); + CryptoCipher cipher; + switch (mode) { + case Cipher.ENCRYPT_MODE: + cipher = encryptor; + break; + case Cipher.DECRYPT_MODE: + cipher = decryptor; + break; + default: + throw new IllegalArgumentException(String.valueOf(mode)); + } - int scale = 1; - while (true) { - int size = in.length * scale; - byte[] buffer = new byte[size]; - try { - int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) - : cipher.update(in, 0, in.length, buffer, 0); - if (outSize != buffer.length) { - byte[] output = new byte[outSize]; - System.arraycopy(buffer, 0, output, 0, output.length); - return output; - } else { - return buffer; + Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error."); + + try { + int scale = 1; + while (true) { + int size = in.length * scale; + byte[] buffer = new byte[size]; + try { + int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) + : cipher.update(in, 0, in.length, buffer, 0); + if (outSize != buffer.length) { + byte[] output = new byte[outSize]; + System.arraycopy(buffer, 0, output, 0, output.length); + return output; + } else { + return buffer; + } + } catch (ShortBufferException e) { + // Try again with a bigger buffer. + scale *= 2; } - } catch (ShortBufferException e) { - // Try again with a bigger buffer. 
- scale *= 2; } + } catch (InternalError ie) { + // SPARK-25535. The commons-crypto library will throw InternalError if something goes wrong, + // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards. + if (mode == Cipher.ENCRYPT_MODE) { + this.encryptor = null; + } else { + this.decryptor = null; + } + throw ie; } } http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java ---------------------------------------------------------------------- diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index b64e4b7..2745052 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -107,45 +107,72 @@ public class TransportCipher { private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; + private boolean isCipherValid; EncryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); cos = cipher.createOutputStream(byteChannel); + isCipherValid = true; } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + ctx.write(new EncryptedMessage(this, cos, msg, byteChannel), promise); } @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { try { - cos.close(); + if (isCipherValid) { + cos.close(); + } } finally { super.close(ctx, promise); } } + + /** + * SPARK-25535. Workaround for CRYPTO-141. 
Avoid further interaction with the underlying cipher + * after an error occurs. + */ + void reportError() { + this.isCipherValid = false; + } + + boolean isCipherValid() { + return isCipherValid; + } } private static class DecryptionHandler extends ChannelInboundHandlerAdapter { private final CryptoInputStream cis; private final ByteArrayReadableChannel byteChannel; + private boolean isCipherValid; DecryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayReadableChannel(); cis = cipher.createInputStream(byteChannel); + isCipherValid = true; } @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + if (!isCipherValid) { + throw new IOException("Cipher is in invalid state."); + } byteChannel.feedData((ByteBuf) data); byte[] decryptedData = new byte[byteChannel.readableBytes()]; int offset = 0; while (offset < decryptedData.length) { - offset += cis.read(decryptedData, offset, decryptedData.length - offset); + // SPARK-25535: workaround for CRYPTO-141. + try { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } catch (InternalError ie) { + isCipherValid = false; + throw ie; + } } ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); @@ -154,7 +181,9 @@ public class TransportCipher { @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { - cis.close(); + if (isCipherValid) { + cis.close(); + } } finally { super.channelInactive(ctx); } @@ -165,8 +194,9 @@ public class TransportCipher { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; + private final CryptoOutputStream cos; + private final EncryptionHandler handler; private long transferred; - private CryptoOutputStream cos; // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has // to utilize two helper ByteArrayWritableChannel for streaming. 
One is used to receive raw data @@ -176,9 +206,14 @@ public class TransportCipher { private ByteBuffer currentEncrypted; - EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + EncryptedMessage( + EncryptionHandler handler, + CryptoOutputStream cos, + Object msg, + ByteArrayWritableChannel ch) { Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, "Unrecognized message type: %s", msg.getClass().getName()); + this.handler = handler; this.isByteBuf = msg instanceof ByteBuf; this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; @@ -261,6 +296,9 @@ public class TransportCipher { } private void encryptMore() throws IOException { + if (!handler.isCipherValid()) { + throw new IOException("Cipher is in invalid state."); + } byteRawChannel.reset(); if (isByteBuf) { @@ -269,8 +307,14 @@ public class TransportCipher { } else { region.transferTo(byteRawChannel, region.transferred()); } - cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); - cos.flush(); + + try { + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + } catch (InternalError ie) { + handler.reportError(); + throw ie; + } currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), 0, byteEncChannel.length()); http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java ---------------------------------------------------------------------- diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index a3519fe..c0aa298 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -18,8 +18,11 @@ package 
org.apache.spark.network.crypto; import java.util.Arrays; +import java.util.Map; +import java.security.InvalidKeyException; import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.collect.ImmutableMap; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; @@ -104,4 +107,18 @@ public class AuthEngineSuite { challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); } + @Test(expected = InvalidKeyException.class) + public void testBadKeySize() throws Exception { + Map<String, String> mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42"); + TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf)); + + try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) { + engine.challenge(); + fail("Should have failed to create challenge message."); + + // Call close explicitly to make sure it's idempotent. + engine.close(); + } + } + } http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 0062197..18b735b 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{Closeable, InputStream, IOException, OutputStream} import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties @@ -54,8 +54,10 @@ private[spark] object CryptoStreamUtils extends Logging { val params = new CryptoParams(key, sparkConf) val iv = createInitializationVector(params.conf) os.write(iv) - new 
CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingOutputStream( + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)), + os) } /** @@ -70,8 +72,10 @@ private[spark] object CryptoStreamUtils extends Logging { val helper = new CryptoHelperChannel(channel) helper.write(ByteBuffer.wrap(iv)) - new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingWritableChannel( + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)), + helper) } /** @@ -84,8 +88,10 @@ private[spark] object CryptoStreamUtils extends Logging { val iv = new Array[Byte](IV_LENGTH_IN_BYTES) ByteStreams.readFully(is, iv) val params = new CryptoParams(key, sparkConf) - new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingInputStream( + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)), + is) } /** @@ -100,8 +106,10 @@ private[spark] object CryptoStreamUtils extends Logging { JavaUtils.readFully(channel, buf) val params = new CryptoParams(key, sparkConf) - new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingReadableChannel( + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)), + channel) } def toCryptoConf(conf: SparkConf): Properties = { @@ -157,6 +165,117 @@ private[spark] object CryptoStreamUtils extends Logging { } + /** + * SPARK-25535. The commons-crypto library will throw InternalError if something goes + * wrong, and leave bad state behind in the Java wrappers, so it's not safe to use them + * afterwards. 
This wrapper detects that situation and avoids further calls into the + * commons-crypto code, while still allowing the underlying streams to be closed. + * + * This should be removed once CRYPTO-141 is fixed (and Spark upgrades its commons-crypto + * dependency). + */ + trait BaseErrorHandler extends Closeable { + + private var closed = false + + /** The encrypted stream that may get into an unhealthy state. */ + protected def cipherStream: Closeable + + /** + * The underlying stream that is being wrapped by the encrypted stream, so that it can be + * closed even if there's an error in the crypto layer. + */ + protected def original: Closeable + + protected def safeCall[T](fn: => T): T = { + if (closed) { + throw new IOException("Cipher stream is closed.") + } + try { + fn + } catch { + case ie: InternalError => + closed = true + original.close() + throw ie + } + } + + override def close(): Unit = { + if (!closed) { + cipherStream.close() + } + } + + } + + // Visible for testing. + class ErrorHandlingReadableChannel( + protected val cipherStream: ReadableByteChannel, + protected val original: ReadableByteChannel) + extends ReadableByteChannel with BaseErrorHandler { + + override def read(src: ByteBuffer): Int = safeCall { + cipherStream.read(src) + } + + override def isOpen(): Boolean = cipherStream.isOpen() + + } + + private class ErrorHandlingInputStream( + protected val cipherStream: InputStream, + protected val original: InputStream) + extends InputStream with BaseErrorHandler { + + override def read(b: Array[Byte]): Int = safeCall { + cipherStream.read(b) + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = safeCall { + cipherStream.read(b, off, len) + } + + override def read(): Int = safeCall { + cipherStream.read() + } + } + + private class ErrorHandlingWritableChannel( + protected val cipherStream: WritableByteChannel, + protected val original: WritableByteChannel) + extends WritableByteChannel with BaseErrorHandler { + + override def 
write(src: ByteBuffer): Int = safeCall { + cipherStream.write(src) + } + + override def isOpen(): Boolean = cipherStream.isOpen() + + } + + private class ErrorHandlingOutputStream( + protected val cipherStream: OutputStream, + protected val original: OutputStream) + extends OutputStream with BaseErrorHandler { + + override def flush(): Unit = safeCall { + cipherStream.flush() + } + + override def write(b: Array[Byte]): Unit = safeCall { + cipherStream.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = safeCall { + cipherStream.write(b, off, len) + } + + override def write(b: Int): Unit = safeCall { + cipherStream.write(b) + } + } + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { val keySpec = new SecretKeySpec(key, "AES") http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 78f618f..0d3611c 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,13 +16,16 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} -import java.nio.channels.Channels +import java.io._ +import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel} import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.Files import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.apache.spark._ import org.apache.spark.internal.config._ @@ -164,6 +167,36 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } 
+ test("error handling wrapper") { + val wrapped = mock(classOf[ReadableByteChannel]) + val decrypted = mock(classOf[ReadableByteChannel]) + val errorHandler = new CryptoStreamUtils.ErrorHandlingReadableChannel(decrypted, wrapped) + + when(decrypted.read(any(classOf[ByteBuffer]))) + .thenThrow(new IOException()) + .thenThrow(new InternalError()) + .thenReturn(1) + + val out = ByteBuffer.allocate(1) + intercept[IOException] { + errorHandler.read(out) + } + intercept[InternalError] { + errorHandler.read(out) + } + + val e = intercept[IOException] { + errorHandler.read(out) + } + assert(e.getMessage().contains("is closed")) + errorHandler.close() + + verify(decrypted, times(2)).read(any(classOf[ByteBuffer])) + verify(wrapped, never()).read(any(classOf[ByteBuffer])) + verify(decrypted, never()).close() + verify(wrapped, times(1)).close() + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org