This is an automated email from the ASF dual-hosted git repository. vanzin pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 216eeec [SPARK-26604][CORE][BACKPORT-2.4] Clean up channel registration for StreamManager 216eeec is described below commit 216eeec2bd319f1d6a1228c9bc8d8a579d5e6571 Author: Liang-Chi Hsieh <vii...@gmail.com> AuthorDate: Thu Mar 7 19:48:20 2019 -0800 [SPARK-26604][CORE][BACKPORT-2.4] Clean up channel registration for StreamManager ## What changes were proposed in this pull request? This is mostly a clean backport of https://github.com/apache/spark/pull/23521 to branch-2.4. ## How was this patch tested? I've tested this with a hack in `TransportRequestHandler` to force `ChunkFetchRequest` to get dropped. I then made a number of `ExternalShuffleClient.fetchChunk` requests (which send `OpenBlocks` and then `ChunkFetchRequest`) and closed out of my test harness. A heap dump later reveals that the `StreamState` references are unreachable. I haven't run this through the unit test suite, but am doing that now. I wanted to get this up because I think folks are waiting for it for 2.4.1. Closes #24013 from abellina/SPARK-26604_cherry_pick_2_4. 
Lead-authored-by: Liang-Chi Hsieh <vii...@gmail.com> Co-authored-by: Alessandro Bellina <abell...@yahoo-inc.com> Signed-off-by: Marcelo Vanzin <van...@cloudera.com> --- .../network/server/OneForOneStreamManager.java | 25 ++++++++++++---------- .../apache/spark/network/server/StreamManager.java | 10 --------- .../network/server/TransportRequestHandler.java | 1 - .../network/TransportRequestHandlerSuite.java | 9 ++++++-- .../server/OneForOneStreamManagerSuite.java | 5 +++-- .../shuffle/ExternalShuffleBlockHandler.java | 2 +- .../shuffle/ExternalShuffleBlockHandlerSuite.java | 3 ++- .../spark/network/netty/NettyBlockRpcServer.scala | 3 ++- 8 files changed, 29 insertions(+), 29 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 0f6a882..6fafcc1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -23,6 +23,7 @@ import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.netty.channel.Channel; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -49,7 +50,7 @@ public class OneForOneStreamManager extends StreamManager { final Iterator<ManagedBuffer> buffers; // The channel associated to the stream - Channel associatedChannel = null; + final Channel associatedChannel; // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. 
@@ -58,9 +59,10 @@ public class OneForOneStreamManager extends StreamManager { // Used to keep track of the number of chunks being transferred and not finished yet. volatile long chunksBeingTransferred = 0L; - StreamState(String appId, Iterator<ManagedBuffer> buffers) { + StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) { this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); + this.associatedChannel = channel; } } @@ -72,13 +74,6 @@ public class OneForOneStreamManager extends StreamManager { } @Override - public void registerChannel(Channel channel, long streamId) { - if (streams.containsKey(streamId)) { - streams.get(streamId).associatedChannel = channel; - } - } - - @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); if (chunkIndex != state.curChunk) { @@ -195,11 +190,19 @@ public class OneForOneStreamManager extends StreamManager { * * If an app ID is provided, only callers who've authenticated with the given app ID will be * allowed to fetch from this stream. + * + * This method also associates the stream with a single client connection, which is guaranteed + * to be the only reader of the stream. Once the connection is closed, the stream will never + * be used again, enabling cleanup by `connectionTerminated`. 
*/ - public long registerStream(String appId, Iterator<ManagedBuffer> buffers) { + public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) { long myStreamId = nextStreamId.getAndIncrement(); - streams.put(myStreamId, new StreamState(appId, buffers)); + streams.put(myStreamId, new StreamState(appId, buffers, channel)); return myStreamId; } + @VisibleForTesting + public int numStreamStates() { + return streams.size(); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java index c535295..e48d27b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -61,16 +61,6 @@ public abstract class StreamManager { } /** - * Associates a stream with a single client connection, which is guaranteed to be the only reader - * of the stream. The getChunk() method will be called serially on this connection and once the - * connection is closed, the stream will never be used again, enabling cleanup. - * - * This must be called before the first getChunk() on the stream, but it may be invoked multiple - * times with the same channel and stream id. - */ - public void registerChannel(Channel channel, long streamId) { } - - /** * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not * to read from the associated streams again, so any state can be cleaned up. 
*/ diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 9fac96d..77a194b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -127,7 +127,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { ManagedBuffer buf; try { streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); - streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 2656cbe..0b565f2 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -62,8 +62,10 @@ public class TransportRequestHandlerSuite { managedBuffers.add(new TestManagedBuffer(20)); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); - long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); - streamManager.registerChannel(channel, streamId); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); + + assert streamManager.numStreamStates() == 1; + TransportClient reverseClient = mock(TransportClient.class); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 
2L); @@ -98,6 +100,9 @@ public class TransportRequestHandlerSuite { requestHandler.handle(request3); verify(channel, times(1)).close(); assert responseAndPromisePairs.size() == 3; + + streamManager.connectionTerminated(channel); + assert streamManager.numStreamStates() == 0; } private class ExtendedChannelPromise extends DefaultChannelPromise { diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index c647525..4248762 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -37,14 +37,15 @@ public class OneForOneStreamManagerSuite { TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); buffers.add(buffer1); buffers.add(buffer2); - long streamId = manager.registerStream("appId", buffers.iterator()); Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); - manager.registerChannel(dummyChannel, streamId); + manager.registerStream("appId", buffers.iterator(), dummyChannel); + assert manager.numStreamStates() == 1; manager.connectionTerminated(dummyChannel); Mockito.verify(buffer1, Mockito.times(1)).release(); Mockito.verify(buffer2, Mockito.times(1)).release(); + assert manager.numStreamStates() == 0; } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa79..732b920 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -91,7 +91,7 @@ public class 
ExternalShuffleBlockHandler extends RpcHandler { OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); long streamId = streamManager.registerStream(client.getClientId(), - new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); + new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel()); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7846b71..1e4eda0 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -101,7 +101,8 @@ public class ExternalShuffleBlockHandlerSuite { @SuppressWarnings("unchecked") ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>) (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class); - verify(streamManager, times(1)).registerStream(anyString(), stream.capture()); + verify(streamManager, times(1)).registerStream(anyString(), stream.capture(), + any()); Iterator<ManagedBuffer> buffers = stream.getValue(); assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7076701..27f4f94 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -59,7 +59,8 @@ class NettyBlockRpcServer( val blocksNum = openBlocks.blockIds.length val blocks = for (i <- (0 until 
blocksNum).view) yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) - val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava, + client.getChannel) logTrace(s"Registered streamId $streamId with $blocksNum buffers") responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org