This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ef8fb9b [SPARK-38694][TESTS] Simplify Java UT code with Junit `assertThrows` Api
ef8fb9b is described below

commit ef8fb9b9d84b6adfe5a4e03b6e775e709d624144
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Wed Mar 30 18:32:37 2022 -0500

[SPARK-38694][TESTS] Simplify Java UT code with Junit `assertThrows` Api

### What changes were proposed in this pull request?

Spark's Java unit tests contain code patterns like the following:

```java
@Test
public void testAuthReplay() throws Exception {
  try {
    doSomeOperation();
    fail("Should have failed");
  } catch (Exception e) {
    assertTrue(doExceptionCheck(e));
  }
}
```

or

```java
@Test(expected = SomeException.class)
public void testAuthReplay() throws Exception {
  try {
    doSomeOperation();
    fail("Should have failed");
  } catch (Exception e) {
    assertTrue(doExceptionCheck(e));
    throw e;
  }
}
```

This PR uses the JUnit `assertThrows` API to simplify these patterns, e.g. `Exception e = assertThrows(Exception.class, () -> doSomeOperation()); assertTrue(doExceptionCheck(e));`.

### Why are the changes needed?

To simplify the test code.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Passed GitHub Actions (GA).

Closes #36008 from LuciferYang/SPARK-38694.
Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: Sean Owen <sro...@gmail.com> --- .../spark/util/kvstore/InMemoryStoreSuite.java | 21 +- .../apache/spark/util/kvstore/LevelDBSuite.java | 21 +- .../apache/spark/util/kvstore/RocksDBSuite.java | 21 +- .../spark/network/crypto/AuthIntegrationSuite.java | 39 +- .../spark/network/crypto/TransportCipherSuite.java | 21 +- .../apache/spark/network/sasl/SparkSaslSuite.java | 41 +- .../server/OneForOneStreamManagerSuite.java | 23 +- .../spark/network/sasl/SaslIntegrationSuite.java | 37 +- .../network/shuffle/ExternalBlockHandlerSuite.java | 14 +- .../shuffle/ExternalShuffleBlockResolverSuite.java | 17 +- .../shuffle/ExternalShuffleSecuritySuite.java | 16 +- .../shuffle/OneForOneBlockFetcherSuite.java | 14 +- .../shuffle/RemoteBlockPushResolverSuite.java | 464 +++++++++------------ .../apache/spark/unsafe/types/UTF8StringSuite.java | 8 +- .../apache/spark/launcher/SparkLauncherSuite.java | 15 +- .../shuffle/sort/PackedRecordPointerSuite.java | 14 +- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 40 +- .../java/test/org/apache/spark/JavaAPISuite.java | 16 +- .../spark/launcher/CommandBuilderUtilsSuite.java | 7 +- .../apache/spark/launcher/LauncherServerSuite.java | 14 +- .../JavaRandomForestClassifierSuite.java | 8 +- .../regression/JavaRandomForestRegressorSuite.java | 8 +- .../spark/ml/util/JavaDefaultReadWriteSuite.java | 8 +- .../expressions/RowBasedKeyValueBatchSuite.java | 60 +-- .../spark/sql/JavaBeanDeserializationSuite.java | 15 +- .../spark/sql/JavaColumnExpressionSuite.java | 16 +- 26 files changed, 317 insertions(+), 661 deletions(-) diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java index 198b6e8..b2acd1a 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java +++ 
b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java @@ -34,24 +34,14 @@ public class InMemoryStoreSuite { t.id = "id"; t.name = "name"; - try { - store.read(CustomType1.class, t.key); - fail("Expected exception for non-existent object."); - } catch (NoSuchElementException nsee) { - // Expected. - } + assertThrows(NoSuchElementException.class, () -> store.read(CustomType1.class, t.key)); store.write(t); assertEquals(t, store.read(t.getClass(), t.key)); assertEquals(1L, store.count(t.getClass())); store.delete(t.getClass(), t.key); - try { - store.read(t.getClass(), t.key); - fail("Expected exception for deleted object."); - } catch (NoSuchElementException nsee) { - // Expected. - } + assertThrows(NoSuchElementException.class, () -> store.read(t.getClass(), t.key)); } @Test @@ -78,12 +68,7 @@ public class InMemoryStoreSuite { store.delete(t1.getClass(), t1.key); assertEquals(t2, store.read(t2.getClass(), t2.key)); store.delete(t2.getClass(), t2.key); - try { - store.read(t2.getClass(), t2.key); - fail("Expected exception for deleted object."); - } catch (NoSuchElementException nsee) { - // Expected. - } + assertThrows(NoSuchElementException.class, () -> store.read(t2.getClass(), t2.key)); } @Test diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index c43c9b1..a7a2148 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -71,36 +71,21 @@ public class LevelDBSuite { db.close(); db = null; - try { - db = new LevelDB(dbpath); - fail("Should have failed version check."); - } catch (UnsupportedStoreVersionException e) { - // Expected. 
- } + assertThrows(UnsupportedStoreVersionException.class, () -> db = new LevelDB(dbpath)); } @Test public void testObjectWriteReadDelete() throws Exception { CustomType1 t = createCustomType1(1); - try { - db.read(CustomType1.class, t.key); - fail("Expected exception for non-existent object."); - } catch (NoSuchElementException nsee) { - // Expected. - } + assertThrows(NoSuchElementException.class, () -> db.read(CustomType1.class, t.key)); db.write(t); assertEquals(t, db.read(t.getClass(), t.key)); assertEquals(1L, db.count(t.getClass())); db.delete(t.getClass(), t.key); - try { - db.read(t.getClass(), t.key); - fail("Expected exception for deleted object."); - } catch (NoSuchElementException nsee) { - // Expected. - } + assertThrows(NoSuchElementException.class, () -> db.read(t.getClass(), t.key)); // Look into the actual DB and make sure that all the keys related to the type have been // removed. diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java index 04463ee..8112cbf 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java @@ -69,36 +69,21 @@ public class RocksDBSuite { db.close(); db = null; - try { - db = new RocksDB(dbpath); - fail("Should have failed version check."); - } catch (UnsupportedStoreVersionException e) { - // Expected. - } + assertThrows(UnsupportedStoreVersionException.class, () -> db = new RocksDB(dbpath)); } @Test public void testObjectWriteReadDelete() throws Exception { CustomType1 t = createCustomType1(1); - try { - db.read(CustomType1.class, t.key); - fail("Expected exception for non-existent object."); - } catch (NoSuchElementException nsee) { - // Expected. 
- } + assertThrows(NoSuchElementException.class, () -> db.read(CustomType1.class, t.key)); db.write(t); assertEquals(t, db.read(t.getClass(), t.key)); assertEquals(1L, db.count(t.getClass())); db.delete(t.getClass(), t.key); - try { - db.read(t.getClass(), t.key); - fail("Expected exception for deleted object."); - } catch (NoSuchElementException nsee) { - // Expected. - } + assertThrows(NoSuchElementException.class, () -> db.read(t.getClass(), t.key)); // Look into the actual DB and make sure that all the keys related to the type have been // removed. diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index d4bf28e..62ccccb 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -72,13 +72,9 @@ public class AuthIntegrationSuite { ctx = new AuthTestCtx(); ctx.createServer("server"); - try { - ctx.createClient("client"); - fail("Should have failed to create client."); - } catch (Exception e) { - assertFalse(ctx.authRpcHandler.isAuthenticated()); - assertFalse(ctx.serverChannel.isActive()); - } + assertThrows(Exception.class, () -> ctx.createClient("client")); + assertFalse(ctx.authRpcHandler.isAuthenticated()); + assertFalse(ctx.serverChannel.isActive()); } @Test @@ -115,13 +111,9 @@ public class AuthIntegrationSuite { assertNotNull(ctx.client.getChannel().pipeline() .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); - - try { - ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); - fail("Should have failed unencrypted RPC."); - } catch (Exception e) { - assertTrue(ctx.authRpcHandler.isAuthenticated()); - } + assertThrows(Exception.class, + () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000)); + 
assertTrue(ctx.authRpcHandler.isAuthenticated()); } @Test @@ -147,17 +139,14 @@ public class AuthIntegrationSuite { ctx.createServer("secret"); ctx.createClient("secret"); - try { - ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); - fail("Should have failed unencrypted RPC."); - } catch (Exception e) { - assertTrue(ctx.authRpcHandler.isAuthenticated()); - assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD")); - // Verify we receive the complete error message - int messageStart = e.getMessage().indexOf("DDDDD"); - int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5; - assertEquals(testErrorMessageLength, messageEnd - messageStart); - } + Exception e = assertThrows(Exception.class, + () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000)); + assertTrue(ctx.authRpcHandler.isAuthenticated()); + assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD")); + // Verify we receive the complete error message + int messageStart = e.getMessage().indexOf("DDDDD"); + int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5; + assertEquals(testErrorMessageLength, messageEnd - messageStart); } private static class AuthTestCtx { diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java index cff115d..cde5c1c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java @@ -32,7 +32,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.fail; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static 
org.mockito.Mockito.mock; @@ -67,21 +67,12 @@ public class TransportCipherSuite { ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2 }); ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2 }); - try { - channel.writeInbound(buffer); - fail("Should have raised InternalError"); - } catch (InternalError expected) { - // expected - assertEquals(0, buffer.refCnt()); - } + assertThrows(InternalError.class, () -> channel.writeInbound(buffer)); + assertEquals(0, buffer.refCnt()); - try { - channel.writeInbound(buffer2); - fail("Should have raised an exception"); - } catch (Throwable expected) { - assertEquals(expected.getClass(), IOException.class); - assertEquals(0, buffer2.refCnt()); - } + Throwable expected = assertThrows(Throwable.class, () -> channel.writeInbound(buffer2)); + assertEquals(expected.getClass(), IOException.class); + assertEquals(0, buffer2.refCnt()); // Simulate closing the connection assertFalse(channel.finish()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 2b0bcca..6096cd3 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -147,13 +147,11 @@ public class SparkSaslSuite { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); - try { + try (SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false)) { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { - ctx.close(); // There should be 2 terminated events; one for the client, one for the server. 
Throwable error = null; long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); @@ -301,19 +299,11 @@ public class SparkSaslSuite { } @Test - public void testServerAlwaysEncrypt() throws Exception { - SaslTestCtx ctx = null; - try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, - ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true")); - fail("Should have failed to connect without encryption."); - } catch (Exception e) { - assertTrue(e.getCause() instanceof SaslException); - } finally { - if (ctx != null) { - ctx.close(); - } - } + public void testServerAlwaysEncrypt() { + Exception re = assertThrows(Exception.class, + () -> new SaslTestCtx(mock(RpcHandler.class), false, false, + ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"))); + assertTrue(re.getCause() instanceof SaslException); } @Test @@ -321,18 +311,11 @@ public class SparkSaslSuite { // This test sets up an encrypted connection but then, using a client bootstrap, removes // the encryption handler from the client side. This should cause the server to not be // able to understand RPCs sent to it and thus close the connection. 
- SaslTestCtx ctx = null; - try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), - TimeUnit.SECONDS.toMillis(10)); - fail("Should have failed to send RPC to server."); - } catch (Exception e) { + try (SaslTestCtx ctx = new SaslTestCtx(mock(RpcHandler.class), true, true)) { + Exception e = assertThrows(Exception.class, + () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10))); assertFalse(e.getCause() instanceof TimeoutException); - } finally { - if (ctx != null) { - ctx.close(); - } } } @@ -362,7 +345,7 @@ public class SparkSaslSuite { } } - private static class SaslTestCtx { + private static class SaslTestCtx implements AutoCloseable { final TransportClient client; final TransportServer server; @@ -423,7 +406,7 @@ public class SparkSaslSuite { this.disableClientEncryption = disableClientEncryption; } - void close() { + public void close() { if (!disableClientEncryption) { assertEquals(encrypt, checker.foundEncryptionHandler); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index 634b40e..b65daaf 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -118,21 +118,12 @@ public class OneForOneStreamManagerSuite { Assert.assertEquals(2, manager.numStreamStates()); - try { - manager.connectionTerminated(dummyChannel); - Assert.fail("connectionTerminated should throw exception when fails to release all buffers"); - - } catch (RuntimeException e) { - - Mockito.verify(buffers, Mockito.times(1)).hasNext(); - Mockito.verify(buffers, Mockito.times(1)).next(); - - Mockito.verify(buffers2, Mockito.times(2)).hasNext(); - 
Mockito.verify(buffers2, Mockito.times(2)).next(); - - Mockito.verify(mockManagedBuffer, Mockito.times(1)).release(); - - Assert.assertEquals(0, manager.numStreamStates()); - } + Assert.assertThrows(RuntimeException.class, () -> manager.connectionTerminated(dummyChannel)); + Mockito.verify(buffers, Mockito.times(1)).hasNext(); + Mockito.verify(buffers, Mockito.times(1)).next(); + Mockito.verify(buffers2, Mockito.times(2)).hasNext(); + Mockito.verify(buffers2, Mockito.times(2)).next(); + Mockito.verify(mockManagedBuffer, Mockito.times(1)).release(); + Assert.assertEquals(0, manager.numStreamStates()); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 96dfc3b..ec749cb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -108,13 +108,10 @@ public class SaslIntegrationSuite { clientFactory = context.createClientFactory( Arrays.asList(new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); - try { - // Bootstrap should fail on startup. - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - fail("Connection should have failed."); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); - } + // Bootstrap should fail on startup. 
+ Exception e = assertThrows(Exception.class, + () -> clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())); + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); } @Test @@ -122,20 +119,14 @@ public class SaslIntegrationSuite { clientFactory = context.createClientFactory(new ArrayList<>()); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - try { - client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); - fail("Should have failed"); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); - } - - try { - // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); - fail("Should have failed"); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); - } + Exception e1 = assertThrows(Exception.class, + () -> client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS)); + assertTrue(e1.getMessage(), e1.getMessage().contains("Expected SaslMessage")); + + // Guessing the right tag byte doesn't magically get you in... 
+ Exception e2 = assertThrows(Exception.class, + () -> client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS)); + assertTrue(e2.getMessage(), e2.getMessage().contains("java.lang.IndexOutOfBoundsException")); } @Test @@ -145,8 +136,8 @@ public class SaslIntegrationSuite { clientFactory = context.createClientFactory( Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); try (TransportServer server = context.createServer()) { - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - } catch (Exception e) { + Exception e = assertThrows(Exception.class, + () -> clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())); assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index d45cbd5..14896c8 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -332,21 +332,11 @@ public class ExternalBlockHandlerSuite { RpcResponseCallback callback = mock(RpcResponseCallback.class); ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); - try { - handler.receive(client, unserializableMsg, callback); - fail("Should have thrown"); - } catch (Exception e) { - // pass - } + assertThrows(Exception.class, () -> handler.receive(client, unserializableMsg, callback)); ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); - try { - handler.receive(client, unexpectedMsg, callback); - fail("Should have thrown"); - } catch (UnsupportedOperationException e) { - // pass - } + assertThrows(Exception.class, () -> handler.receive(client, 
unexpectedMsg, callback)); verify(callback, never()).onSuccess(any(ByteBuffer.class)); verify(callback, never()).onFailure(any(Throwable.class)); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 04d4bdf..ec195e8 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -64,22 +64,15 @@ public class ExternalShuffleBlockResolverSuite { public void testBadRequests() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor - try { - resolver.getBlockData("app0", "exec1", 1, 1, 0); - fail("Should have failed"); - } catch (RuntimeException e) { - assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); - } + RuntimeException e = assertThrows(RuntimeException.class, + () -> resolver.getBlockData("app0", "exec1", 1, 1, 0)); + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); // Nonexistent shuffle block resolver.registerExecutor("app0", "exec3", dataContext.createExecutorInfo(SORT_MANAGER)); - try { - resolver.getBlockData("app0", "exec3", 1, 1, 0); - fail("Should have failed"); - } catch (Exception e) { - // pass - } + assertThrows(Exception.class, + () -> resolver.getBlockData("app0", "exec3", 1, 1, 0)); } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 883e643..c52ac31 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -70,20 +70,16 @@ public class ExternalShuffleSecuritySuite { @Test public void testBadAppId() { - try { - validate("wrong-app-id", "secret", false); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); - } + Exception e = assertThrows(Exception.class, + () -> validate("wrong-app-id", "secret", false)); + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); } @Test public void testBadSecret() { - try { - validate("my-app-id", "bad-secret", false); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); - } + Exception e = assertThrows(Exception.class, + () -> validate("my-app-id", "bad-secret", false)); + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index cc4640d..5f3d3c8 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -189,16 +189,10 @@ public class OneForOneBlockFetcherSuite { @Test public void testEmptyBlockFetch() { - try { - fetchBlocks( - Maps.newLinkedHashMap(), - new String[] {}, - new OpenBlocks("app-id", "exec-id", new String[] {}), - conf); - fail(); - } catch (IllegalArgumentException e) { - assertEquals("Zero-sized blockIds array", e.getMessage()); - } + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> fetchBlocks(Maps.newLinkedHashMap(), new String[] {}, + new OpenBlocks("app-id", "exec-id", new String[] {}), conf)); + assertEquals("Zero-sized blockIds array", e.getMessage()); } 
@Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java index f76afae..20aae7c 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java @@ -32,7 +32,6 @@ import java.util.Map; import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadLocalRandom; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.apache.commons.io.FileUtils; @@ -120,14 +119,11 @@ public class RemoteBlockPushResolverSuite { assertTrue(errorHandler.shouldLogError(new Throwable())); } - @Test(expected = RuntimeException.class) + @Test public void testNoIndexFile() { - try { - pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); - } catch (Throwable t) { - assertTrue(t.getMessage().startsWith("Merged shuffle index file")); - Throwables.propagate(t); - } + RuntimeException re = assertThrows(RuntimeException.class, + () -> pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0)); + assertTrue(re.getMessage().startsWith("Merged shuffle index file")); } @Test @@ -303,7 +299,7 @@ public class RemoteBlockPushResolverSuite { validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); } - @Test(expected = BlockPushNonFatalFailure.class) + @Test public void testBlockReceivedAfterMergeFinalize() throws IOException { ByteBuffer[] blocks = new ByteBuffer[]{ ByteBuffer.wrap(new byte[4]), @@ -319,18 +315,15 @@ public class RemoteBlockPushResolverSuite { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0)); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4])); - try { - stream1.onComplete(stream1.getID()); - } catch 
(BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream1.getID()); - MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); - validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); - throw e; - } + BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class, + () -> stream1.onComplete(stream1.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream1.getID()); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); + validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); } @Test @@ -367,7 +360,7 @@ public class RemoteBlockPushResolverSuite { assertArrayEquals(expectedBytes, mb.nioByteBuffer().array()); } - @Test(expected = BlockPushNonFatalFailure.class) + @Test public void testCollision() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( @@ -379,19 +372,16 @@ public class RemoteBlockPushResolverSuite { // This should be deferred stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5])); // Since stream2 didn't get any opportunity it will throw couldn't find opportunity error - try { - stream2.onComplete(stream2.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream2.getID()); - 
throw e; - } + BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class, + () -> stream2.onComplete(stream2.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream2.getID()); } - @Test(expected = BlockPushNonFatalFailure.class) + @Test public void testFailureInAStreamDoesNotInterfereWithStreamWhichIsWriting() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( @@ -408,17 +398,13 @@ public class RemoteBlockPushResolverSuite { // This should be deferred stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[5])); // Since this stream didn't get any opportunity it will throw couldn't find opportunity error - BlockPushNonFatalFailure failedEx = null; - try { - stream3.onComplete(stream3.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream3.getID()); - failedEx = e; - } + BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class, + () -> stream3.onComplete(stream3.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream3.getID()); // stream 1 now completes stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); stream1.onComplete(stream1.getID()); @@ -426,12 +412,9 @@ public class RemoteBlockPushResolverSuite { pushResolver.finalizeShuffleMerge(new 
FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0)); MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}}); - if (failedEx != null) { - throw failedEx; - } } - @Test(expected = IllegalArgumentException.class) + @Test public void testUpdateLocalDirsOnlyOnce() throws IOException { String testApp = "updateLocalDirsOnlyOnceTest"; Path[] activeLocalDirs = createLocalDirs(1); @@ -449,32 +432,25 @@ public class RemoteBlockPushResolverSuite { assertTrue(pushResolver.getMergedBlockDirs(testApp)[0].contains( activeLocalDirs[0].toFile().getPath())); removeApplication(testApp); - try { - pushResolver.getMergedBlockDirs(testApp); - } catch (IllegalArgumentException e) { - assertEquals(e.getMessage(), - "application " + testApp + " is not registered or NM was restarted."); - throw e; - } + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> pushResolver.getMergedBlockDirs(testApp)); + assertEquals(e.getMessage(), + "application " + testApp + " is not registered or NM was restarted."); } - @Test(expected = IllegalArgumentException.class) + @Test public void testExecutorRegisterWithInvalidJsonForPushShuffle() throws IOException { String testApp = "executorRegisterWithInvalidShuffleManagerMeta"; Path[] activeLocalDirs = createLocalDirs(1); - try { - registerExecutor(testApp, prepareLocalDirs(activeLocalDirs, MERGE_DIRECTORY), - INVALID_MERGE_DIRECTORY_META); - } catch (IllegalArgumentException re) { - assertEquals( - "Failed to get the merge directory information from the shuffleManagerMeta " + - "shuffleManager:{\"mergeDirInvalid\": \"merge_manager_2\", \"attemptId\": \"2\"} in " + - "executor registration message", re.getMessage()); - throw re; - } + IllegalArgumentException re = assertThrows(IllegalArgumentException.class, + () -> registerExecutor(testApp, prepareLocalDirs(activeLocalDirs, MERGE_DIRECTORY), + INVALID_MERGE_DIRECTORY_META)); + 
assertEquals("Failed to get the merge directory information from the shuffleManagerMeta " + + "shuffleManager:{\"mergeDirInvalid\": \"merge_manager_2\", \"attemptId\": \"2\"} in " + + "executor registration message", re.getMessage()); } - @Test(expected = IllegalArgumentException.class) + @Test public void testExecutorRegistrationFromTwoAppAttempts() throws IOException { String testApp = "testExecutorRegistrationFromTwoAppAttempts"; Path[] attempt1LocalDirs = createLocalDirs(1); @@ -502,13 +478,10 @@ public class RemoteBlockPushResolverSuite { assertTrue(pushResolver.getMergedBlockDirs(testApp)[0].contains( attempt2LocalDirs[0].toFile().getPath())); removeApplication(testApp); - try { - pushResolver.getMergedBlockDirs(testApp); - } catch (IllegalArgumentException e) { - assertEquals(e.getMessage(), - "application " + testApp + " is not registered or NM was restarted."); - throw e; - } + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> pushResolver.getMergedBlockDirs(testApp)); + assertEquals(e.getMessage(), + "application " + testApp + " is not registered or NM was restarted."); } @Test @@ -673,7 +646,7 @@ public class RemoteBlockPushResolverSuite { validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}}); } - @Test(expected = IllegalStateException.class) + @Test public void testIOExceptionsExceededThreshold() throws IOException { RemoteBlockPushResolver.PushBlockStreamCallback callback = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( @@ -687,28 +660,23 @@ public class RemoteBlockPushResolverSuite { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); - try { - callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[2])); - } catch (IOException ioe) { - // this will throw IOException so the client can 
retry. - callback1.onFailure(callback1.getID(), ioe); - } + IOException ioe = assertThrows(IOException.class, + () -> callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[2]))); + // this will throw IOException so the client can retry. + callback1.onFailure(callback1.getID(), ioe); } assertEquals(4, partitionInfo.getNumIOExceptions()); // After 4 IOException, the server will respond with IOExceptions exceeded threshold - try { - RemoteBlockPushResolver.PushBlockStreamCallback callback2 = - (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0)); - callback2.onData(callback.getID(), ByteBuffer.wrap(new byte[1])); - } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0", - t.getMessage()); - throw t; - } + RemoteBlockPushResolver.PushBlockStreamCallback callback2 = + (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0)); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> callback2.onData(callback.getID(), ByteBuffer.wrap(new byte[1]))); + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0", + e.getMessage()); } - @Test(expected = IllegalStateException.class) + @Test public void testIOExceptionsDuringMetaUpdateIncreasesExceptionCount() throws IOException { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback = @@ -730,37 +698,29 @@ public class RemoteBlockPushResolverSuite { assertEquals(4, partitionInfo.getNumIOExceptions()); // After 4 IOException, the server will respond with IOExceptions exceeded threshold for any // new request for this partition. 
- try { - RemoteBlockPushResolver.PushBlockStreamCallback callback2 = + RemoteBlockPushResolver.PushBlockStreamCallback callback2 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0)); - callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[4])); - callback2.onComplete(callback2.getID()); - } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0", - t.getMessage()); - throw t; - } + callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[4])); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> callback2.onComplete(callback2.getID())); + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0", + e.getMessage()); } - @Test(expected = IllegalStateException.class) - public void testRequestForAbortedShufflePartitionThrowsException() { - try { - testIOExceptionsDuringMetaUpdateIncreasesExceptionCount(); - } catch (Throwable t) { - // No more blocks can be merged to this partition. - } - try { - pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 10, 0, 0)); - } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_10_0", - t.getMessage()); - throw t; - } + @Test + public void testRequestForAbortedShufflePartitionThrowsException() throws IOException { + // No more blocks can be merged to this partition. 
+ testIOExceptionsDuringMetaUpdateIncreasesExceptionCount(); + + IllegalStateException t = assertThrows(IllegalStateException.class, + () -> pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 10, 0, 0))); + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_10_0", + t.getMessage()); } - @Test(expected = IllegalStateException.class) + @Test public void testPendingBlockIsAbortedImmediately() throws IOException { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback = @@ -773,27 +733,25 @@ public class RemoteBlockPushResolverSuite { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); - try { - callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); + callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); + if (i < 5) { // This will complete without any exceptions but the exception count is increased. callback1.onComplete(callback1.getID()); - } catch (Throwable t) { + } else { + Throwable t = assertThrows(Throwable.class, () -> callback1.onComplete(callback1.getID())); callback1.onFailure(callback1.getID(), t); } } assertEquals(5, partitionInfo.getNumIOExceptions()); // The server will respond with IOExceptions exceeded threshold for any additional attempts // to write. 
- try { - callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])); - } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0", - t.getMessage()); - throw t; - } + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]))); + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0", + e.getMessage()); } - @Test(expected = IllegalStateException.class) + @Test public void testWritingPendingBufsIsAbortedImmediatelyDuringComplete() throws IOException { useTestFiles(true, false); RemoteBlockPushResolver.PushBlockStreamCallback callback = @@ -806,13 +764,9 @@ public class RemoteBlockPushResolverSuite { RemoteBlockPushResolver.PushBlockStreamCallback callback1 = (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream( new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0)); - try { - callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); - // This will complete without any exceptions but the exception count is increased. - callback1.onComplete(callback1.getID()); - } catch (Throwable t) { - callback1.onFailure(callback1.getID(), t); - } + callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5])); + // This will complete without any exceptions but the exception count is increased. + callback1.onComplete(callback1.getID()); } assertEquals(4, partitionInfo.getNumIOExceptions()); RemoteBlockPushResolver.PushBlockStreamCallback callback2 = @@ -822,22 +776,16 @@ public class RemoteBlockPushResolverSuite { // This is deferred callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])); // Callback2 completes which will throw another exception. 
- try { - callback2.onComplete(callback2.getID()); - } catch (Throwable t) { - callback2.onFailure(callback2.getID(), t); - } + Throwable t = assertThrows(Throwable.class, () -> callback2.onComplete(callback2.getID())); + callback2.onFailure(callback2.getID(), t); assertEquals(5, partitionInfo.getNumIOExceptions()); // Restore index file so that any further writes to it are successful and any exceptions are // due to IOExceptions exceeding threshold. testIndexFile.restore(); - try { - callback.onComplete(callback.getID()); - } catch (Throwable t) { - assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0", - t.getMessage()); - throw t; - } + IllegalStateException ie = assertThrows(IllegalStateException.class, + () -> callback.onComplete(callback.getID())); + assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0", + ie.getMessage()); } @Test @@ -894,7 +842,7 @@ public class RemoteBlockPushResolverSuite { removeApplication(TEST_APP); } - @Test(expected = BlockPushNonFatalFailure.class) + @Test public void testFailureAfterDuplicateBlockDoesNotInterfereActiveStream() throws IOException { StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( @@ -918,17 +866,13 @@ public class RemoteBlockPushResolverSuite { new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0)); // This should be deferred as stream 2 is still the active stream stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2])); - BlockPushNonFatalFailure failedEx = null; - try { - stream3.onComplete(stream3.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream3.getID()); - failedEx = e; - } + BlockPushNonFatalFailure e = 
assertThrows(BlockPushNonFatalFailure.class, + () -> stream3.onComplete(stream3.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream3.getID()); // Stream 2 writes more and completes stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4])); stream2.onComplete(stream2.getID()); @@ -936,12 +880,9 @@ public class RemoteBlockPushResolverSuite { MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}}); removeApplication(TEST_APP); - if (failedEx != null) { - throw failedEx; - } } - @Test(expected = BlockPushNonFatalFailure.class) + @Test public void testPushBlockFromPreviousAttemptIsRejected() throws IOException, InterruptedException { Semaphore closed = new Semaphore(0); @@ -997,22 +938,19 @@ public class RemoteBlockPushResolverSuite { assertFalse(partitionInfo.getMetaFile().getChannel().isOpen()); assertFalse(partitionInfo.getIndexFile().getChannel().isOpen()); } - try { - pushResolver.receiveBlockDataAsStream( - new PushBlockStream(testApp, 1, 0, 0, 1, 0, 0)); - } catch (BlockPushNonFatalFailure re) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(re.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_OLD_ATTEMPT_PUSH.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream2.getID()); - throw re; - } + BlockPushNonFatalFailure re = assertThrows(BlockPushNonFatalFailure.class, + () -> pushResolver.receiveBlockDataAsStream( + new PushBlockStream(testApp, 1, 0, 0, 1, 0, 0))); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(re.getResponse()); + 
assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_OLD_ATTEMPT_PUSH.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream2.getID()); } - @Test(expected = IllegalArgumentException.class) + @Test public void testFinalizeShuffleMergeFromPreviousAttemptIsAborted() - throws IOException, InterruptedException { + throws IOException { String testApp = "testFinalizeShuffleMergeFromPreviousAttemptIsAborted"; Path[] attempt1LocalDirs = createLocalDirs(1); registerExecutor(testApp, @@ -1032,15 +970,13 @@ public class RemoteBlockPushResolverSuite { registerExecutor(testApp, prepareLocalDirs(attempt2LocalDirs, MERGE_DIRECTORY + "_" + ATTEMPT_ID_2), MERGE_DIRECTORY_META_2); - try { - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0, 0)); - } catch (IllegalArgumentException e) { - assertEquals(e.getMessage(), - String.format("The attempt id %s in this FinalizeShuffleMerge message does not " + - "match with the current attempt id %s stored in shuffle service for application %s", - ATTEMPT_ID_1, ATTEMPT_ID_2, testApp)); - throw e; - } + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0, 0))); + assertEquals(e.getMessage(), + String.format("The attempt id %s in this FinalizeShuffleMerge message does not " + + "match with the current attempt id %s stored in shuffle service for application %s", + ATTEMPT_ID_1, ATTEMPT_ID_2, testApp)); } @Test(expected = ClosedChannelException.class) @@ -1095,16 +1031,13 @@ public class RemoteBlockPushResolverSuite { stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); - try { - // stream 1 push should be rejected as it is from an older shuffleMergeId - stream1.onComplete(stream1.getID()); - } catch (BlockPushNonFatalFailure e) { - 
BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream1.getID()); - } + BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class, + () -> stream1.onComplete(stream1.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream1.getID()); // stream 2 now completes stream2.onComplete(stream2.getID()); pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2)); @@ -1124,25 +1057,22 @@ public class RemoteBlockPushResolverSuite { stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); - try { - // stream 1 push should be rejected as it is from an older shuffleMergeId - stream1.onComplete(stream1.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream1.getID()); - } + // stream 1 push should be rejected as it is from an older shuffleMergeId + BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class, + () -> stream1.onComplete(stream1.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, 
stream1.getID()); // stream 2 now completes stream2.onComplete(stream2.getID()); - try { - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1)); - } catch(RuntimeException re) { - assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale" - + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle" - + " is already being pushed", re.getMessage()); - } + RuntimeException re = assertThrows(RuntimeException.class, + () -> pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1))); + assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale" + + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle" + + " is already being pushed", re.getMessage()); pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2)); MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0); @@ -1180,42 +1110,33 @@ public class RemoteBlockPushResolverSuite { stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); - try { - // stream 1 push should be rejected as it is from an older shuffleMergeId - stream1.onComplete(stream1.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(), - errorCode.returnCode); - assertEquals(errorCode.failureBlockId, stream1.getID()); - } + // stream 1 push should be rejected as it is from an older shuffleMergeId + BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class, + () -> stream1.onComplete(stream1.getID())); + BlockPushReturnCode errorCode = + (BlockPushReturnCode) 
BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(), + errorCode.returnCode); + assertEquals(errorCode.failureBlockId, stream1.getID()); // stream 2 now completes stream2.onComplete(stream2.getID()); pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2)); - try { - pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0); - } catch(RuntimeException re) { - assertEquals("MergedBlockMeta fetch for shuffle 0 with shuffleMergeId 0 reduceId 0" - + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for" - + " the shuffle is available", re.getMessage()); - } - - try { - pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1)); - } catch(RuntimeException re) { - assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale" - + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle" - + " is already being pushed", re.getMessage()); - } - try { - pushResolver.getMergedBlockData(TEST_APP, 0, 1, 0, 0); - } catch(RuntimeException re) { - assertEquals("MergedBlockData fetch for shuffle 0 with shuffleMergeId 1 reduceId 0" - + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for" - + " the shuffle is available", re.getMessage()); - } - + RuntimeException re0 = assertThrows(RuntimeException.class, + () -> pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0)); + assertEquals("MergedBlockMeta fetch for shuffle 0 with shuffleMergeId 0 reduceId 0" + + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for" + + " the shuffle is available", re0.getMessage()); + RuntimeException re1 = assertThrows(RuntimeException.class, + () -> pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1))); + assertEquals("Shuffle merge finalize request for shuffle 0 
with shuffleMergeId 1 is stale" + + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle" + + " is already being pushed", re1.getMessage()); + RuntimeException re2 = assertThrows(RuntimeException.class, + () -> pushResolver.getMergedBlockData(TEST_APP, 0, 1, 0, 0)); + assertEquals("MergedBlockData fetch for shuffle 0 with shuffleMergeId 1 reduceId 0" + + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for" + + " the shuffle is available", re2.getMessage()); MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0); validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}}); } @@ -1324,37 +1245,28 @@ public class RemoteBlockPushResolverSuite { stream1.onComplete(stream1.getID()); //shuffle 1 0 is finalized pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 1, 0)); - BlockPushNonFatalFailure errorToValidate = null; - try { - //shufflePush_1_0_0_200 is received by the server after finalization of shuffle 1 0 which - //should be rejected - StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 200, 0)); - failureCallback.onComplete(failureCallback.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), - errorCode.returnCode); - errorToValidate = e; - assertEquals(errorCode.failureBlockId, "shufflePush_1_0_0_200"); - } - assertNotNull("shufflePush_1_0_0_200 should be rejected", errorToValidate); - try { - //shufflePush_1_0_1_100 is received by the server after finalization of shuffle 1 0 which - //should also be rejected - StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream( - new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 1, 100, 
0)); - failureCallback.onComplete(failureCallback.getID()); - } catch (BlockPushNonFatalFailure e) { - BlockPushReturnCode errorCode = - (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse()); - assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), - errorCode.returnCode); - errorToValidate = e; - assertEquals(errorCode.failureBlockId, "shufflePush_1_0_1_100"); - } - assertNotNull("shufflePush_1_0_1_100 should be rejected", errorToValidate); + //shufflePush_1_0_0_200 is received by the server after finalization of shuffle 1 0 which + //should be rejected + StreamCallbackWithID failureCallback0 = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 200, 0)); + BlockPushNonFatalFailure e0 = assertThrows(BlockPushNonFatalFailure.class, + () -> failureCallback0.onComplete(failureCallback0.getID())); + BlockPushReturnCode errorCode0 = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e0.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), + errorCode0.returnCode); + assertEquals(errorCode0.failureBlockId, "shufflePush_1_0_0_200"); + //shufflePush_1_0_1_100 is received by the server after finalization of shuffle 1 0 which + //should also be rejected + StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 1, 100, 0)); + BlockPushNonFatalFailure e1 = assertThrows(BlockPushNonFatalFailure.class, + () -> failureCallback.onComplete(failureCallback.getID())); + BlockPushReturnCode errorCode1 = + (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e1.getResponse()); + assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(), + errorCode1.returnCode); + assertEquals(errorCode1.failureBlockId, "shufflePush_1_0_1_100"); MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 1, 0, 100); validateChunks(TEST_APP, 1, 
0, 100, blockMeta, new int[]{4}, new int[][]{{0}}); removeApplication(TEST_APP); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e433dc0..f530c81 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -613,12 +613,8 @@ public class UTF8StringSuite { for (final long offset : offsets) { try { - fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) - .writeTo(outputStream); - - throw new IllegalStateException(Long.toString(offset)); - } catch (ArrayIndexOutOfBoundsException e) { - // ignore + assertThrows(ArrayIndexOutOfBoundsException.class, + () -> fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length).writeTo(outputStream)); } finally { outputStream.reset(); } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 72b1245..5c88fb6 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -51,20 +51,11 @@ public class SparkLauncherSuite extends BaseSuite { SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); launcher.addSparkArg(opts.HELP); - try { - launcher.addSparkArg(opts.PROXY_USER); - fail("Expected IllegalArgumentException."); - } catch (IllegalArgumentException e) { - // Expected. - } + assertThrows(IllegalArgumentException.class, () -> launcher.addSparkArg(opts.PROXY_USER)); launcher.addSparkArg(opts.PROXY_USER, "someUser"); - try { - launcher.addSparkArg(opts.HELP, "someValue"); - fail("Expected IllegalArgumentException."); - } catch (IllegalArgumentException e) { - // Expected. 
- } + assertThrows(IllegalArgumentException.class, + () -> launcher.addSparkArg(opts.HELP, "someValue")); launcher.addSparkArg("--future-argument"); launcher.addSparkArg("--future-argument", "someValue"); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 92bc740..1fd5aab 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -30,7 +30,7 @@ import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZ import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThrows; public class PackedRecordPointerSuite { @@ -86,15 +86,9 @@ public class PackedRecordPointerSuite { @Test public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() { PackedRecordPointer packedPointer = new PackedRecordPointer(); - boolean asserted = false; - try { - // Pointers greater than the maximum partition ID will overflow or trigger an assertion error - packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)); - } catch (AssertionError e ) { - // pass - asserted = true; - } - assertTrue(asserted); + // Pointers greater than the maximum partition ID will overflow or trigger an assertion error + assertThrows(AssertionError.class, + () -> packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1))); assertNotEquals(MAXIMUM_PARTITION_ID + 1, packedPointer.getPartitionId()); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 
a9c81c5..a20a2a0 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -618,33 +618,14 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void initialCapacityBoundsChecking() { - try { - new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES); - Assert.fail("Expected IllegalArgumentException to be thrown"); - } catch (IllegalArgumentException e) { - // expected exception - } - - try { - new BytesToBytesMap( - taskMemoryManager, - BytesToBytesMap.MAX_CAPACITY + 1, - PAGE_SIZE_BYTES); - Assert.fail("Expected IllegalArgumentException to be thrown"); - } catch (IllegalArgumentException e) { - // expected exception - } - - try { - new BytesToBytesMap( - taskMemoryManager, - 1, - TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES + 1); - Assert.fail("Expected IllegalArgumentException to be thrown"); - } catch (IllegalArgumentException e) { - // expected exception - } - + assertThrows(IllegalArgumentException.class, + () -> new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES)); + assertThrows(IllegalArgumentException.class, + () -> new BytesToBytesMap(taskMemoryManager, + BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES)); + assertThrows(IllegalArgumentException.class, + () -> new BytesToBytesMap(taskMemoryManager, 1, + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES + 1)); } @Test @@ -742,10 +723,7 @@ public abstract class AbstractBytesToBytesMapSuite { // Force OOM on next memory allocation. memoryManager.markExecutionAsOutOfMemoryOnce(); try { - map.reset(); - Assert.fail("Expected SparkOutOfMemoryError to be thrown"); - } catch (SparkOutOfMemoryError e) { - // Expected exception; do nothing. 
+ assertThrows(SparkOutOfMemoryError.class, map::reset); } finally { map.free(); } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index fd91237..cba43d9 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -1494,12 +1494,7 @@ public class JavaAPISuite implements Serializable { future.cancel(true); assertTrue(future.isCancelled()); assertTrue(future.isDone()); - try { - future.get(2000, TimeUnit.MILLISECONDS); - fail("Expected future.get() for cancelled job to throw CancellationException"); - } catch (CancellationException ignored) { - // pass - } + assertThrows(CancellationException.class, () -> future.get(2000, TimeUnit.MILLISECONDS)); } @Test @@ -1507,12 +1502,9 @@ public class JavaAPISuite implements Serializable { List<Integer> data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD<Integer> rdd = sc.parallelize(data, 1); JavaFutureAction<Long> future = rdd.map(new BuggyMapFunction<>()).countAsync(); - try { - future.get(2, TimeUnit.SECONDS); - fail("Expected future.get() for failed job to throw ExecutionException"); - } catch (ExecutionException ee) { - assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); - } + ExecutionException ee = assertThrows(ExecutionException.class, + () -> future.get(2, TimeUnit.SECONDS)); + assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); assertTrue(future.isDone()); } diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index 22d9324..46cdffc 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -105,12 +105,7 @@ public class CommandBuilderUtilsSuite { } private static void 
testInvalidOpt(String opts) { - try { - parseOptionString(opts); - fail("Expected exception for invalid option string."); - } catch (IllegalArgumentException e) { - // pass. - } + assertThrows(IllegalArgumentException.class, () -> parseOptionString(opts)); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index f8dc0ec..bf89de9 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -230,16 +230,14 @@ public class LauncherServerSuite extends BaseSuite { private void waitForError(TestClient client, String secret) throws Exception { final AtomicBoolean helloSent = new AtomicBoolean(); eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { - try { - if (!helloSent.get()) { - client.send(new Hello(secret, "1.4.0")); - helloSent.set(true); + if (!helloSent.get()) { + if (client.isOpen()) { + assertThrows(IOException.class, () -> client.send(new SetAppId("appId"))); } else { - client.send(new SetAppId("appId")); + assertThrows(IllegalStateException.class, + () -> client.send(new Hello(secret, "1.4.0"))); + helloSent.set(true); } - fail("Expected error but message went through."); - } catch (IllegalStateException | IOException e) { - // Expected. 
} }); } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index dd98513..5308d61 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -73,12 +73,8 @@ public class JavaRandomForestClassifierSuite extends SharedSparkSession { } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; for (String strategy : invalidStrategies) { - try { - rf.setFeatureSubsetStrategy(strategy); - Assert.fail("Expected exception to be thrown for invalid strategies"); - } catch (Exception e) { - Assert.assertTrue(e instanceof IllegalArgumentException); - } + Assert.assertThrows(IllegalArgumentException.class, + () -> rf.setFeatureSubsetStrategy(strategy)); } RandomForestClassificationModel model = rf.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 4ba13e2..d08040d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -75,12 +75,8 @@ public class JavaRandomForestRegressorSuite extends SharedSparkSession { } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; for (String strategy : invalidStrategies) { - try { - rf.setFeatureSubsetStrategy(strategy); - Assert.fail("Expected exception to be thrown for invalid strategies"); - } catch (Exception e) { - Assert.assertTrue(e instanceof IllegalArgumentException); - } + Assert.assertThrows(IllegalArgumentException.class, + () -> rf.setFeatureSubsetStrategy(strategy)); } RandomForestRegressionModel model 
= rf.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index e4f678f..e4287c4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -49,13 +49,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession { instance.set(instance.intParam(), 2); String outputPath = new File(tempDir, uid).getPath(); instance.save(outputPath); - try { - instance.save(outputPath); - Assert.fail( - "Write without overwrite enabled should fail if the output directory already exists."); - } catch (IOException e) { - // expected - } + Assert.assertThrows(IOException.class, () -> instance.save(outputPath)); instance.write().session(spark).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index d460a06..c7fdcc6 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -127,43 +127,10 @@ public class RowBasedKeyValueBatchSuite { try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { Assert.assertEquals(0, batch.numRows()); - - boolean asserted = false; - try { - batch.getKeyRow(-1); - } catch (AssertionError e) { - // Expected exception; do nothing. 
- asserted = true; - } - Assert.assertTrue("Should not be able to get row -1", asserted); - - asserted = false; - try { - batch.getValueRow(-1); - } catch (AssertionError e) { - // Expected exception; do nothing. - asserted = true; - } - Assert.assertTrue("Should not be able to get row -1", asserted); - - asserted = false; - try { - batch.getKeyRow(0); - } catch (AssertionError e) { - // Expected exception; do nothing. - asserted = true; - } - Assert.assertTrue("Should not be able to get row 0 when batch is empty", asserted); - - asserted = false; - try { - batch.getValueRow(0); - } catch (AssertionError e) { - // Expected exception; do nothing. - asserted = true; - } - Assert.assertTrue("Should not be able to get row 0 when batch is empty", asserted); - + Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(-1)); + Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(-1)); + Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(0)); + Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(0)); Assert.assertFalse(batch.rowIterator().next()); } } @@ -199,23 +166,8 @@ public class RowBasedKeyValueBatchSuite { UnsafeRow retrievedValue2 = batch.getValueRow(2); Assert.assertTrue(checkValue(retrievedValue2, 3, 3)); - boolean asserted = false; - try { - batch.getKeyRow(3); - } catch (AssertionError e) { - // Expected exception; do nothing. - asserted = true; - } - Assert.assertTrue("Should not be able to get row 3", asserted); - - asserted = false; - try { - batch.getValueRow(3); - } catch (AssertionError e) { - // Expected exception; do nothing. 
- asserted = true; - } - Assert.assertTrue("Should not be able to get row 3", asserted); + Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(3)); + Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(3)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index af0a22b..3d78e06 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -183,18 +183,9 @@ public class JavaBeanDeserializationSuite implements Serializable { Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema); - try { - dataFrame.as(encoder).collect(); - Assert.fail("Expected AnalysisException, but passed."); - } catch (Throwable e) { - // Here we need to handle weird case: compiler complains AnalysisException never be thrown - // in try statement, but it can be thrown actually. Maybe Scala-Java interop issue? 
- if (e instanceof AnalysisException) { - Assert.assertTrue(e.getMessage().contains("Cannot up cast ")); - } else { - throw e; - } - } + AnalysisException e = Assert.assertThrows(AnalysisException.class, + () -> dataFrame.as(encoder).collect()); + Assert.assertTrue(e.getMessage().contains("Cannot up cast ")); } private static Row createRecordSpark22000Row(Long index) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java index 4478742..7f9fdbd 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java @@ -79,14 +79,12 @@ public class JavaColumnExpressionSuite { createStructField("a", IntegerType, false), createStructField("b", createArrayType(IntegerType, false), false))); Dataset<Row> df = spark.createDataFrame(rows, schema); - try { - df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); - Assert.fail("Expected org.apache.spark.sql.AnalysisException"); - } catch (Exception e) { - Arrays.asList("cannot resolve", - "due to data type mismatch: Arguments must be same type but were") - .forEach(s -> Assert.assertTrue( - e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); - } + Exception e = Assert.assertThrows(Exception.class, + () -> df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))))); + Arrays.asList("cannot resolve", + "due to data type mismatch: Arguments must be same type but were") + .forEach(s -> + Assert.assertTrue(e.getMessage().toLowerCase(Locale.ROOT) + .contains(s.toLowerCase(Locale.ROOT)))); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org