This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5668c42  [SPARK-27021][CORE] Cleanup of Netty event loop group for shuffle chunk fetch requests
5668c42 is described below

commit 5668c42edf20bc577305437622272bf803b6019e
Author: “attilapiros” <piros.attila.zs...@gmail.com>
AuthorDate: Tue Mar 5 12:31:06 2019 -0800

    [SPARK-27021][CORE] Cleanup of Netty event loop group for shuffle chunk fetch requests
    
    ## What changes were proposed in this pull request?
    
    Creating a Netty `EventLoopGroup` creates a new thread pool for handling the events. To stop the threads of that pool, the event loop group must be shut down, which is done properly for transport servers and clients by calling, for example, the `shutdownGracefully()` method (for details see the `close()` methods of `TransportClientFactory` and `TransportServer`). But there is a separate event loop group for shuffle chunk fetch requests which is in the pipeline for handling fet [...]
    
    ## How was this patch tested?
    
    With existing unit tests.
    
    This leak is present in production too, but its effect shows up most clearly in the unit tests.
    
    Checking the core unit test logs before the PR:
    ```
    $ grep "LEAK IN SUITE" unit-tests.log | grep -o shuffle-chunk-fetch-handler 
| wc -l
    381
    ```
    
    And after the PR, without whitelisting in the thread audit and with an extra `await` after the `chunkFetchWorkers.shutdownGracefully()` call:
    ```
    $ grep "LEAK IN SUITE" unit-tests.log | grep -o shuffle-chunk-fetch-handler 
| wc -l
    0
    ```
    
    Closes #23930 from attilapiros/SPARK-27021.
    
    Authored-by: “attilapiros” <piros.attila.zs...@gmail.com>
    Signed-off-by: Marcelo Vanzin <van...@cloudera.com>
---
 .../org/apache/spark/network/TransportContext.java |  15 +--
 .../spark/network/ChunkFetchIntegrationSuite.java  |   4 +-
 .../network/RequestTimeoutIntegrationSuite.java    |  10 +-
 .../apache/spark/network/RpcIntegrationSuite.java  |   4 +-
 .../java/org/apache/spark/network/StreamSuite.java |   4 +-
 .../spark/network/TransportClientFactorySuite.java |  78 +++++++-------
 .../spark/network/crypto/AuthIntegrationSuite.java |   3 +
 .../apache/spark/network/sasl/SparkSaslSuite.java  |   6 +-
 .../network/util/NettyMemoryMetricsSuite.java      |   5 +-
 .../spark/network/sasl/SaslIntegrationSuite.java   |  91 ++++++++--------
 .../shuffle/ExternalShuffleIntegrationSuite.java   |   4 +-
 .../shuffle/ExternalShuffleSecuritySuite.java      |  10 +-
 .../spark/network/yarn/YarnShuffleService.java     |   7 +-
 .../spark/deploy/ExternalShuffleService.scala      |   8 +-
 .../network/netty/NettyBlockTransferService.scala  |   3 +
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala   |   3 +
 .../apache/spark/ExternalShuffleServiceSuite.scala |  15 ++-
 .../test/scala/org/apache/spark/ThreadAudit.scala  |  16 ++-
 .../apache/spark/storage/BlockManagerSuite.scala   | 115 +++++++++++----------
 19 files changed, 228 insertions(+), 173 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index 0bc5dd5..d99b9bd 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network;
 
+import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -60,13 +61,12 @@ import org.apache.spark.network.util.TransportFrameDecoder;
  * channel. As each TransportChannelHandler contains a TransportClient, this enables server
  * processes to send messages back to the client on an existing channel.
  */
-public class TransportContext {
+public class TransportContext implements Closeable {
   private static final Logger logger = LoggerFactory.getLogger(TransportContext.class);
 
   private final TransportConf conf;
   private final RpcHandler rpcHandler;
   private final boolean closeIdleConnections;
-  private final boolean isClientOnly;
   // Number of registered connections to the shuffle service
   private Counter registeredConnections = new Counter();
 
@@ -120,7 +120,6 @@ public class TransportContext {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
     this.closeIdleConnections = closeIdleConnections;
-    this.isClientOnly = isClientOnly;
 
     if (conf.getModuleName() != null &&
         conf.getModuleName().equalsIgnoreCase("shuffle") &&
@@ -200,9 +199,7 @@ public class TransportContext {
         // would require more logic to guarantee if this were not part of the same event loop.
         .addLast("handler", channelHandler);
       // Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle rpcs.
-      if (conf.getModuleName() != null &&
-          conf.getModuleName().equalsIgnoreCase("shuffle")
-          && !isClientOnly) {
+      if (chunkFetchWorkers != null) {
         pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler);
       }
       return channelHandler;
@@ -240,4 +237,10 @@ public class TransportContext {
   public Counter getRegisteredConnections() {
     return registeredConnections;
   }
+
+  public void close() {
+    if (chunkFetchWorkers != null) {
+      chunkFetchWorkers.shutdownGracefully();
+    }
+  }
 }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index ab4dd04..5999b62 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -56,6 +56,7 @@ public class ChunkFetchIntegrationSuite {
   static final int BUFFER_CHUNK_INDEX = 0;
   static final int FILE_CHUNK_INDEX = 1;
 
+  static TransportContext context;
   static TransportServer server;
   static TransportClientFactory clientFactory;
   static StreamManager streamManager;
@@ -117,7 +118,7 @@ public class ChunkFetchIntegrationSuite {
         return streamManager;
       }
     };
-    TransportContext context = new TransportContext(conf, handler);
+    context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
   }
@@ -127,6 +128,7 @@ public class ChunkFetchIntegrationSuite {
     bufferChunk.release();
     server.close();
     clientFactory.close();
+    context.close();
     testFile.delete();
   }
 
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index c0724e0..15a28ba 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -48,6 +48,7 @@ import java.util.concurrent.TimeUnit;
  */
 public class RequestTimeoutIntegrationSuite {
 
+  private TransportContext context;
   private TransportServer server;
   private TransportClientFactory clientFactory;
 
@@ -79,6 +80,9 @@ public class RequestTimeoutIntegrationSuite {
     if (clientFactory != null) {
       clientFactory.close();
     }
+    if (context != null) {
+      context.close();
+    }
   }
 
   // Basic suite: First request completes quickly, and second waits for longer than network timeout.
@@ -106,7 +110,7 @@ public class RequestTimeoutIntegrationSuite {
       }
     };
 
-    TransportContext context = new TransportContext(conf, handler);
+    context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -153,7 +157,7 @@ public class RequestTimeoutIntegrationSuite {
       }
     };
 
-    TransportContext context = new TransportContext(conf, handler);
+    context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
 
@@ -204,7 +208,7 @@ public class RequestTimeoutIntegrationSuite {
       }
     };
 
-    TransportContext context = new TransportContext(conf, handler);
+    context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 1c0aa4d..117f1e4 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -44,6 +44,7 @@ import org.apache.spark.network.util.TransportConf;
 
 public class RpcIntegrationSuite {
   static TransportConf conf;
+  static TransportContext context;
   static TransportServer server;
   static TransportClientFactory clientFactory;
   static RpcHandler rpcHandler;
@@ -90,7 +91,7 @@ public class RpcIntegrationSuite {
       @Override
       public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
     };
-    TransportContext context = new TransportContext(conf, rpcHandler);
+    context = new TransportContext(conf, rpcHandler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
     oneWayMsgs = new ArrayList<>();
@@ -160,6 +161,7 @@ public class RpcIntegrationSuite {
   public static void tearDown() {
     server.close();
     clientFactory.close();
+    context.close();
     testData.cleanup();
   }
 
diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
index f3050cb..485d8ad 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -51,6 +51,7 @@ public class StreamSuite {
   private static final String[] STREAMS = StreamTestHelper.STREAMS;
   private static StreamTestHelper testData;
 
+  private static TransportContext context;
   private static TransportServer server;
   private static TransportClientFactory clientFactory;
 
@@ -93,7 +94,7 @@ public class StreamSuite {
         return streamManager;
       }
     };
-    TransportContext context = new TransportContext(conf, handler);
+    context = new TransportContext(conf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
   }
@@ -103,6 +104,7 @@ public class StreamSuite {
     server.close();
     clientFactory.close();
     testData.cleanup();
+    context.close();
   }
 
   @Test
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index e95d25f..2c62114 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -64,6 +64,7 @@ public class TransportClientFactorySuite {
   public void tearDown() {
     JavaUtils.closeQuietly(server1);
     JavaUtils.closeQuietly(server2);
+    JavaUtils.closeQuietly(context);
   }
 
   /**
@@ -80,49 +81,50 @@ public class TransportClientFactorySuite {
     TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
 
     RpcHandler rpcHandler = new NoOpRpcHandler();
-    TransportContext context = new TransportContext(conf, rpcHandler);
-    TransportClientFactory factory = context.createClientFactory();
-    Set<TransportClient> clients = Collections.synchronizedSet(
-      new HashSet<TransportClient>());
-
-    AtomicInteger failed = new AtomicInteger();
-    Thread[] attempts = new Thread[maxConnections * 10];
-
-    // Launch a bunch of threads to create new clients.
-    for (int i = 0; i < attempts.length; i++) {
-      attempts[i] = new Thread(() -> {
-        try {
-          TransportClient client =
-            factory.createClient(TestUtils.getLocalHost(), server1.getPort());
-          assertTrue(client.isActive());
-          clients.add(client);
-        } catch (IOException e) {
-          failed.incrementAndGet();
-        } catch (InterruptedException e) {
-          throw new RuntimeException(e);
+    try (TransportContext context = new TransportContext(conf, rpcHandler)) {
+      TransportClientFactory factory = context.createClientFactory();
+      Set<TransportClient> clients = Collections.synchronizedSet(
+          new HashSet<TransportClient>());
+
+      AtomicInteger failed = new AtomicInteger();
+      Thread[] attempts = new Thread[maxConnections * 10];
+
+      // Launch a bunch of threads to create new clients.
+      for (int i = 0; i < attempts.length; i++) {
+        attempts[i] = new Thread(() -> {
+          try {
+            TransportClient client =
+                factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+            assertTrue(client.isActive());
+            clients.add(client);
+          } catch (IOException e) {
+            failed.incrementAndGet();
+          } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+          }
+        });
+
+        if (concurrent) {
+          attempts[i].start();
+        } else {
+          attempts[i].run();
         }
-      });
+      }
 
-      if (concurrent) {
-        attempts[i].start();
-      } else {
-        attempts[i].run();
+      // Wait until all the threads complete.
+      for (Thread attempt : attempts) {
+        attempt.join();
       }
-    }
 
-    // Wait until all the threads complete.
-    for (Thread attempt : attempts) {
-      attempt.join();
-    }
+      Assert.assertEquals(0, failed.get());
+      Assert.assertEquals(clients.size(), maxConnections);
 
-    Assert.assertEquals(0, failed.get());
-    Assert.assertEquals(clients.size(), maxConnections);
+      for (TransportClient client : clients) {
+        client.close();
+      }
 
-    for (TransportClient client : clients) {
-      client.close();
+      factory.close();
     }
-
-    factory.close();
   }
 
   @Test
@@ -204,8 +206,8 @@ public class TransportClientFactorySuite {
         throw new UnsupportedOperationException();
       }
     });
-    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
-    try (TransportClientFactory factory = context.createClientFactory()) {
+    try (TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
+         TransportClientFactory factory = context.createClientFactory()) {
       TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
       assertTrue(c1.isActive());
       long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
index 8751944..8a0ff54 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -196,6 +196,9 @@ public class AuthIntegrationSuite {
       if (server != null) {
         server.close();
       }
+      if (ctx != null) {
+        ctx.close();
+      }
     }
 
     private SecretKeyHolder createKeyHolder(String secret) {
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 59adf97..cf2d72f 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -365,6 +365,7 @@ public class SparkSaslSuite {
 
     final TransportClient client;
     final TransportServer server;
+    final TransportContext ctx;
 
     private final boolean encrypt;
     private final boolean disableClientEncryption;
@@ -396,7 +397,7 @@ public class SparkSaslSuite {
       when(keyHolder.getSaslUser(anyString())).thenReturn("user");
       when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
 
-      TransportContext ctx = new TransportContext(conf, rpcHandler);
+      this.ctx = new TransportContext(conf, rpcHandler);
 
       this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME);
 
@@ -431,6 +432,9 @@ public class SparkSaslSuite {
       if (server != null) {
         server.close();
       }
+      if (ctx != null) {
+        ctx.close();
+      }
     }
 
   }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java
index 400b385..f049cad 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/util/NettyMemoryMetricsSuite.java
@@ -60,11 +60,14 @@ public class NettyMemoryMetricsSuite {
       JavaUtils.closeQuietly(clientFactory);
       clientFactory = null;
     }
-
     if (server != null) {
       JavaUtils.closeQuietly(server);
       server = null;
     }
+    if (context != null) {
+      JavaUtils.closeQuietly(context);
+      context = null;
+    }
   }
 
   @Test
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 02e6eb3..57c1c5e 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -91,6 +91,7 @@ public class SaslIntegrationSuite {
   @AfterClass
   public static void afterAll() {
     server.close();
+    context.close();
   }
 
   @After
@@ -153,13 +154,14 @@ public class SaslIntegrationSuite {
   @Test
   public void testNoSaslServer() {
     RpcHandler handler = new TestRpcHandler();
-    TransportContext context = new TransportContext(conf, handler);
-    clientFactory = context.createClientFactory(
-      Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
-    try (TransportServer server = context.createServer()) {
-      clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    } catch (Exception e) {
-      assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+    try (TransportContext context = new TransportContext(conf, handler)) {
+      clientFactory = context.createClientFactory(
+          Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
+      try (TransportServer server = context.createServer()) {
+        clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+      } catch (Exception e) {
+        assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+      }
     }
   }
 
@@ -174,18 +176,15 @@ public class SaslIntegrationSuite {
     ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
       new OneForOneStreamManager(), blockResolver);
     TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
-    TransportContext blockServerContext = new TransportContext(conf, blockHandler);
-    TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
 
-    TransportClient client1 = null;
-    TransportClient client2 = null;
-    TransportClientFactory clientFactory2 = null;
-    try {
+    try (
+      TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+      TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
       // Create a client, and make a request to fetch blocks from a different app.
-      clientFactory = blockServerContext.createClientFactory(
+      TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
           Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
-      client1 = clientFactory.createClient(TestUtils.getLocalHost(),
-        blockServer.getPort());
+      TransportClient client1 = clientFactory1.createClient(
+          TestUtils.getLocalHost(), blockServer.getPort())) {
 
       AtomicReference<Throwable> exception = new AtomicReference<>();
 
@@ -223,41 +222,33 @@ public class SaslIntegrationSuite {
       StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
       long streamId = stream.streamId;
 
-      // Create a second client, authenticated with a different app ID, and try to read from
-      // the stream created for the previous app.
-      clientFactory2 = blockServerContext.createClientFactory(
-          Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
-      client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
-        blockServer.getPort());
-
-      CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
-      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
-        @Override
-        public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-          chunkReceivedLatch.countDown();
-        }
-        @Override
-        public void onFailure(int chunkIndex, Throwable t) {
-          exception.set(t);
-          chunkReceivedLatch.countDown();
-        }
-      };
-
-      exception.set(null);
-      client2.fetchChunk(streamId, 0, callback);
-      chunkReceivedLatch.await();
-      checkSecurityException(exception.get());
-    } finally {
-      if (client1 != null) {
-        client1.close();
-      }
-      if (client2 != null) {
-        client2.close();
-      }
-      if (clientFactory2 != null) {
-        clientFactory2.close();
+      try (
+        // Create a second client, authenticated with a different app ID, and try to read from
+        // the stream created for the previous app.
+        TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
+            Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
+        TransportClient client2 = clientFactory2.createClient(
+            TestUtils.getLocalHost(), blockServer.getPort())
+      ) {
+        CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
+        ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+          @Override
+          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+            chunkReceivedLatch.countDown();
+          }
+
+          @Override
+          public void onFailure(int chunkIndex, Throwable t) {
+            exception.set(t);
+            chunkReceivedLatch.countDown();
+          }
+        };
+
+        exception.set(null);
+        client2.fetchChunk(streamId, 0, callback);
+        chunkReceivedLatch.await();
+        checkSecurityException(exception.get());
       }
-      blockServer.close();
     }
   }
 
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 526b96b..f5b1ec9 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -58,6 +58,7 @@ public class ExternalShuffleIntegrationSuite {
   static ExternalShuffleBlockHandler handler;
   static TransportServer server;
   static TransportConf conf;
+  static TransportContext transportContext;
 
   static byte[][] exec0Blocks = new byte[][] {
     new byte[123],
@@ -87,7 +88,7 @@ public class ExternalShuffleIntegrationSuite {
 
     conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
     handler = new ExternalShuffleBlockHandler(conf, null);
-    TransportContext transportContext = new TransportContext(conf, handler);
+    transportContext = new TransportContext(conf, handler);
     server = transportContext.createServer();
   }
 
@@ -95,6 +96,7 @@ public class ExternalShuffleIntegrationSuite {
   public static void afterAll() {
     dataContext0.cleanup();
     server.close();
+    transportContext.close();
   }
 
   @After
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 82caf39..67f79021 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -41,14 +41,14 @@ public class ExternalShuffleSecuritySuite {
 
   TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
   TransportServer server;
+  TransportContext transportContext;
 
   @Before
   public void beforeEach() throws IOException {
-    TransportContext context =
-      new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
+    transportContext = new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
     TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
         new TestSecretKeyHolder("my-app-id", "secret"));
-    this.server = context.createServer(Arrays.asList(bootstrap));
+    this.server = transportContext.createServer(Arrays.asList(bootstrap));
   }
 
   @After
@@ -57,6 +57,10 @@ public class ExternalShuffleSecuritySuite {
       server.close();
       server = null;
     }
+    if (transportContext != null) {
+      transportContext.close();
+      transportContext = null;
+    }
   }
 
   @Test
diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index 7e8d3b2..25592e9 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -113,6 +113,8 @@ public class YarnShuffleService extends AuxiliaryService {
   // The actual server that serves shuffle files
   private TransportServer shuffleServer = null;
 
+  private TransportContext transportContext = null;
+
   private Configuration _conf = null;
 
   // The recovery path used to shuffle service recovery
@@ -184,7 +186,7 @@ public class YarnShuffleService extends AuxiliaryService {
 
       int port = conf.getInt(
         SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
-      TransportContext transportContext = new TransportContext(transportConf, blockHandler);
+      transportContext = new TransportContext(transportConf, blockHandler);
       shuffleServer = transportContext.createServer(port, bootstraps);
       // the port should normally be fixed, but for tests its useful to find an open port
       port = shuffleServer.getPort();
@@ -318,6 +320,9 @@ public class YarnShuffleService extends AuxiliaryService {
       if (shuffleServer != null) {
         shuffleServer.close();
       }
+      if (transportContext != null) {
+        transportContext.close();
+      }
       if (blockHandler != null) {
         blockHandler.close();
       }
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index edfd2ea..12ed189 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -52,8 +52,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
   private val transportConf =
     SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0)
   private val blockHandler = newShuffleBlockHandler(transportConf)
-  private val transportContext: TransportContext =
-    new TransportContext(transportConf, blockHandler, true)
+  private var transportContext: TransportContext = _
 
   private var server: TransportServer = _
 
@@ -82,6 +81,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
       } else {
         Nil
       }
+    transportContext = new TransportContext(transportConf, blockHandler, true)
     server = transportContext.createServer(port, bootstraps.asJava)
 
     shuffleServiceSource.registerMetricSet(server.getAllMetrics)
@@ -107,6 +107,10 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
       server.close()
       server = null
     }
+    if (transportContext != null) {
+      transportContext.close()
+      transportContext = null
+    }
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index dc55685..864e8ad 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -182,5 +182,8 @@ private[spark] class NettyBlockTransferService(
     if (clientFactory != null) {
       clientFactory.close()
     }
+    if (transportContext != null) {
+      transportContext.close()
+    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 2540196..472db45 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -315,6 +315,9 @@ private[netty] class NettyRpcEnv(
     if (fileDownloadFactory != null) {
       fileDownloadFactory.close()
     }
+    if (transportContext != null) {
+      transportContext.close()
+    }
   }
 
   override def deserialize[T](deserializationAction: () => T): T = {
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index 262e2a7..8b737cd 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.network.TransportContext
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.server.TransportServer
 import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient}
+import org.apache.spark.util.Utils
 
 /**
  * This suite creates an external shuffle server and routes all shuffle fetches through it.
@@ -33,13 +34,14 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh
  */
 class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
   var server: TransportServer = _
+  var transportContext: TransportContext = _
   var rpcHandler: ExternalShuffleBlockHandler = _
 
   override def beforeAll() {
     super.beforeAll()
     val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2)
     rpcHandler = new ExternalShuffleBlockHandler(transportConf, null)
-    val transportContext = new TransportContext(transportConf, rpcHandler)
+    transportContext = new TransportContext(transportConf, rpcHandler)
     server = transportContext.createServer()
 
     conf.set(config.SHUFFLE_MANAGER, "sort")
@@ -48,11 +50,16 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
   }
 
   override def afterAll() {
-    try {
+    Utils.tryLogNonFatalError {
       server.close()
-    } finally {
-      super.afterAll()
     }
+    Utils.tryLogNonFatalError {
+      rpcHandler.close()
+    }
+    Utils.tryLogNonFatalError {
+      transportContext.close()
+    }
+    super.afterAll()
   }
 
   // This test ensures that the external shuffle service is actually in use for the other tests.
diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala
index b3cea9d..6b91162 100644
--- a/core/src/test/scala/org/apache/spark/ThreadAudit.scala
+++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala
@@ -55,18 +55,26 @@ trait ThreadAudit extends Logging {
      * creates event loops. One is wrapped inside
      * [[org.apache.spark.network.server.TransportServer]]
      * the other one is inside [[org.apache.spark.network.client.TransportClient]].
-     * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]].
-     * Manually checked and all of them stopped properly.
+     * Calling [[SparkContext#stop]] will shut down the thread pool of this event group
+     * asynchronously. In each case proper stopping is checked manually.
      */
     "rpc-client.*",
     "rpc-server.*",
 
     /**
+     * During [[org.apache.spark.network.TransportContext]] construction a separate event loop could
+     * be created for handling ChunkFetchRequest.
+     * Calling [[org.apache.spark.network.TransportContext#close]] will shut down the thread pool
+     * of this event group asynchronously. In each case proper stopping is checked manually.
+     */
+    "shuffle-chunk-fetch-handler.*",
+
+    /**
      * During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside
      * [[org.apache.spark.network.server.TransportServer]]
      * the other one is inside [[org.apache.spark.network.client.TransportClient]].
-     * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]].
-     * Manually checked and all of them stopped properly.
+     * Calling [[SparkContext#stop]] will shut down the thread pool of this event group
+     * asynchronously. In each case proper stopping is checked manually.
      */
     "shuffle-client.*",
     "shuffle-server.*"
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 5dec4f5..115103f 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -895,6 +895,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
       serializerManager, conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, 0)
+    allStores += store
     store.initialize("app-id")
 
     // The put should fail since a1 is not serializable.
@@ -1360,74 +1361,76 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val tryAgainExecutor = "tryAgainExecutor"
     val succeedingExecutor = "succeedingExecutor"
 
-    // a server which delays response 50ms and must try twice for success.
-    def newShuffleServer(port: Int): (TransportServer, Int) = {
-      val failure = new Exception(tryAgainMsg)
-      val success = ByteBuffer.wrap(new Array[Byte](0))
+    val failure = new Exception(tryAgainMsg)
+    val success = ByteBuffer.wrap(new Array[Byte](0))
 
-      var secondExecutorFailedOnce = false
-      var thirdExecutorFailedOnce = false
+    var secondExecutorFailedOnce = false
+    var thirdExecutorFailedOnce = false
 
-      val handler = new NoOpRpcHandler {
-        override def receive(
-            client: TransportClient,
-            message: ByteBuffer,
-            callback: RpcResponseCallback): Unit = {
-          val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message)
-          msgObj match {
+    val handler = new NoOpRpcHandler {
+      override def receive(
+          client: TransportClient,
+          message: ByteBuffer,
+          callback: RpcResponseCallback): Unit = {
+        val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message)
+        msgObj match {
 
-            case exec: RegisterExecutor if exec.execId == timingoutExecutor =>
-              () // No reply to generate client-side timeout
+          case exec: RegisterExecutor if exec.execId == timingoutExecutor =>
+            () // No reply to generate client-side timeout
 
-            case exec: RegisterExecutor
-              if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce =>
-              secondExecutorFailedOnce = true
-              callback.onFailure(failure)
+          case exec: RegisterExecutor
+            if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce =>
+            secondExecutorFailedOnce = true
+            callback.onFailure(failure)
 
-            case exec: RegisterExecutor if exec.execId == tryAgainExecutor =>
-              callback.onSuccess(success)
+          case exec: RegisterExecutor if exec.execId == tryAgainExecutor =>
+            callback.onSuccess(success)
 
-            case exec: RegisterExecutor
-              if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce =>
-              thirdExecutorFailedOnce = true
-              callback.onFailure(failure)
+          case exec: RegisterExecutor
+            if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce =>
+            thirdExecutorFailedOnce = true
+            callback.onFailure(failure)
 
-            case exec: RegisterExecutor if exec.execId == succeedingExecutor =>
-              callback.onSuccess(success)
+          case exec: RegisterExecutor if exec.execId == succeedingExecutor =>
+            callback.onSuccess(success)
 
-          }
         }
       }
-
-      val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0)
-      val transCtx = new TransportContext(transConf, handler, true)
-      (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port)
     }
 
-    val candidatePort = RandomUtils.nextInt(1024, 65536)
-    val (server, shufflePort) = Utils.startServiceOnPort(candidatePort,
-      newShuffleServer, conf, "ShuffleServer")
-
-    conf.set(SHUFFLE_SERVICE_ENABLED.key, "true")
-    conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString)
-    conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40")
-    conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
-    var e = intercept[SparkException] {
-      makeBlockManager(8000, timingoutExecutor)
-    }.getMessage
-    assert(e.contains("TimeoutException"))
-
-    conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
-    conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
-    e = intercept[SparkException] {
-      makeBlockManager(8000, tryAgainExecutor)
-    }.getMessage
-    assert(e.contains(tryAgainMsg))
-
-    conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
-    conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2")
-    makeBlockManager(8000, succeedingExecutor)
-    server.close()
+    val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0)
+
+    Utils.tryWithResource(new TransportContext(transConf, handler, true)) { transCtx =>
+      // a server which delays response 50ms and must try twice for success.
+      def newShuffleServer(port: Int): (TransportServer, Int) = {
+        (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port)
+      }
+
+      val candidatePort = RandomUtils.nextInt(1024, 65536)
+      val (server, shufflePort) = Utils.startServiceOnPort(candidatePort,
+        newShuffleServer, conf, "ShuffleServer")
+
+      conf.set(SHUFFLE_SERVICE_ENABLED.key, "true")
+      conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString)
+      conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40")
+      conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
+      var e = intercept[SparkException] {
+        makeBlockManager(8000, timingoutExecutor)
+      }.getMessage
+      assert(e.contains("TimeoutException"))
+
+      conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
+      conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1")
+      e = intercept[SparkException] {
+        makeBlockManager(8000, tryAgainExecutor)
+      }.getMessage
+      assert(e.contains(tryAgainMsg))
+
+      conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000")
+      conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2")
+      makeBlockManager(8000, succeedingExecutor)
+      server.close()
+    }
   }
 
   test("fetch remote block to local disk if block size is larger than 
threshold") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
